preserve default values for hidden states

This commit is contained in:
Mike J Innes 2016-10-25 19:10:26 +01:00
parent dea85df8b7
commit 1fde7b4615
2 changed files with 9 additions and 2 deletions

View File

@ -14,7 +14,9 @@ end
function makegraph(graph, args)
@assert length(args) == 1
mapconst(graph) do x
x == args[1] ? ModelInput(1) : x
x == args[1] ? ModelInput(1) :
isa(x, Delay) ? :(Delay($(Expr(:quote, x.name)), self.$(x.name))) :
x
end
end

View File

@ -1,7 +1,10 @@
type Delay
name::Symbol
default::Nullable{Param}
end
Delay(name) = Delay(name, nothing)
function liftloops!(ex, params)
e = Flow.normedges(ex)
hidden = intersect((b.args[1] for b in ex.args), params)
@ -42,13 +45,15 @@ end
function break!(g::IVertex)
g = bumpinputs(g)
loops = []
defaults = []
g = prewalk!(g) do v
isa(value(v), Delay) || return v
n = length(loops)+1
push!(loops, unroll!(v, n))
push!(defaults, get(value(v).default))
hinput(n)
end
cse(vertex(tuple, vertex(tuple, loops...), g))
cse(vertex(tuple, vertex(tuple, loops...), g)), defaults
end
# r = Recurrent(10, 10)