fix unrolling
This commit is contained in:
parent
ba60c4596b
commit
42b50c976a
|
@ -1,8 +1,9 @@
|
|||
module Flux
|
||||
|
||||
using MacroTools, Lazy, Flow, Juno
|
||||
import Flow: graphm, syntax, prewalk, postwalk, iscyclic, Constant, constant,
|
||||
isconstant, value, inputs, thread!, value, inputs, Split, Group
|
||||
import Flow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic,
|
||||
Constant, constant, isconstant, value, inputs, thread!, value, inputs,
|
||||
Split, Group, group
|
||||
import Juno: Tree, Row
|
||||
|
||||
# Zero Flux Given
|
||||
|
|
|
@ -53,22 +53,24 @@ function break!(g::IVertex)
|
|||
push!(defaults, get(value(v).default))
|
||||
hiddeninput(n)
|
||||
end
|
||||
cse(vertex(tuple, vertex(tuple, loops...), g)), defaults
|
||||
cse(group(group(loops...), g)), defaults
|
||||
end
|
||||
|
||||
# function unroll(model, n)
|
||||
# graph, defaults = break!(atomise(model))
|
||||
# outputs = [spliceinputs(graph, vertex(tuple, map(constant, defaults)...), inputnode(1))]
|
||||
# for i = 2:n
|
||||
# push!(outputs, spliceinputs(graph, outputs[end][1], constant(ModelInput(i))))
|
||||
# end
|
||||
# state = outputs[end][1]
|
||||
# outputs = map(x -> x[2], outputs)
|
||||
# vertex(tuple, state, vertex(tuple, outputs...))
|
||||
# end
|
||||
function unroll(model, n)
|
||||
graph, defaults = break!(atomise(model))
|
||||
outputs = [spliceinputs(graph, group(map(constant, defaults)...), inputnode(1))]
|
||||
detuple(outputs[end])
|
||||
for i = 2:n
|
||||
push!(outputs, spliceinputs(graph, outputs[end][1], inputnode(i)))
|
||||
end
|
||||
state = outputs[end][1]
|
||||
outputs = map(x -> x[2], outputs)
|
||||
@> group(state, group(outputs...)) detuple
|
||||
end
|
||||
|
||||
# syntax′(x) = syntax(Flow.dl(x), bindconst = true)
|
||||
# r = Recurrent(10,10)
|
||||
# unroll(r, 1) |> syntax |> prettify |> display
|
||||
# unroll(r,5) |> cse |> syntax′ |> prettify |> display
|
||||
|
||||
@net type Recurrent
|
||||
Wx; Wh; B
|
||||
|
|
Loading…
Reference in New Issue