fix unrolling

This commit is contained in:
Mike J Innes 2016-10-26 00:49:32 +01:00
parent ba60c4596b
commit 42b50c976a
2 changed files with 17 additions and 14 deletions

View File

@ -1,8 +1,9 @@
module Flux
using MacroTools, Lazy, Flow, Juno
import Flow: graphm, syntax, prewalk, postwalk, iscyclic, Constant, constant,
isconstant, value, inputs, thread!, value, inputs, Split, Group
import Flow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic,
Constant, constant, isconstant, value, inputs, thread!, value, inputs,
Split, Group, group
import Juno: Tree, Row
# Zero Flux Given

View File

@ -53,22 +53,24 @@ function break!(g::IVertex)
push!(defaults, get(value(v).default))
hiddeninput(n)
end
cse(vertex(tuple, vertex(tuple, loops...), g)), defaults
cse(group(group(loops...), g)), defaults
end
# function unroll(model, n)
# graph, defaults = break!(atomise(model))
# outputs = [spliceinputs(graph, vertex(tuple, map(constant, defaults)...), inputnode(1))]
# for i = 2:n
# push!(outputs, spliceinputs(graph, outputs[end][1], constant(ModelInput(i))))
# end
# state = outputs[end][1]
# outputs = map(x -> x[2], outputs)
# vertex(tuple, state, vertex(tuple, outputs...))
# end
function unroll(model, n)
graph, defaults = break!(atomise(model))
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
detuple(outputs[end])
for i = 2:n
push!(outputs, spliceinputs(graph, outputs[end][1], inputnode(i)))
end
state = outputs[end][1]
outputs = map(x -> x[2], outputs)
@> group(state, group(outputs...)) detuple
end
# syntax(x) = syntax(Flow.dl(x), bindconst = true)
# r = Recurrent(10,10)
# unroll(r, 1) |> syntax |> prettify |> display
# unroll(r,5) |> cse |> syntax |> prettify |> display
@net type Recurrent
Wx; Wh; B