preserve default values for hidden states
This commit is contained in:
parent
dea85df8b7
commit
1fde7b4615
|
@ -14,7 +14,9 @@ end
|
|||
function makegraph(graph, args)
|
||||
@assert length(args) == 1
|
||||
mapconst(graph) do x
|
||||
x == args[1] ? ModelInput(1) : x
|
||||
x == args[1] ? ModelInput(1) :
|
||||
isa(x, Delay) ? :(Delay($(Expr(:quote, x.name)), self.$(x.name))) :
|
||||
x
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -1,7 +1,10 @@
|
|||
type Delay
|
||||
name::Symbol
|
||||
default::Nullable{Param}
|
||||
end
|
||||
|
||||
Delay(name) = Delay(name, nothing)
|
||||
|
||||
function liftloops!(ex, params)
|
||||
e = Flow.normedges(ex)
|
||||
hidden = intersect((b.args[1] for b in ex.args), params)
|
||||
|
@ -42,13 +45,15 @@ end
|
|||
function break!(g::IVertex)
|
||||
g = bumpinputs(g)
|
||||
loops = []
|
||||
defaults = []
|
||||
g = prewalk!(g) do v
|
||||
isa(value(v), Delay) || return v
|
||||
n = length(loops)+1
|
||||
push!(loops, unroll!(v, n))
|
||||
push!(defaults, get(value(v).default))
|
||||
hinput(n)
|
||||
end
|
||||
cse(vertex(tuple, vertex(tuple, loops...), g))
|
||||
cse(vertex(tuple, vertex(tuple, loops...), g)), defaults
|
||||
end
|
||||
|
||||
# r = Recurrent(10, 10)
|
||||
|
|
Loading…
Reference in New Issue