diff --git a/src/Flux.jl b/src/Flux.jl index 9c408768..53a97f21 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -1,9 +1,9 @@ module Flux using MacroTools, Lazy, DataFlow, Juno -import DataFlow: graphm, syntax, prewalk!, prewalk, postwalk, iscyclic, - Constant, constant, isconstant, value, inputs, thread!, value, inputs, - Split, Group, group +import DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk, + iscyclic, Constant, constant, isconstant, Group, group, value, inputs, + thread!, value, inputs, Split import Juno: Tree, Row # Zero Flux Given diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 66084a54..3686ce60 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -40,10 +40,10 @@ function collect_state(v::IVertex) isa(value(v), Offset) || return v if (i = findfirst(state, v[1])) == 0 push!(state, v[1]) - push!(offset, value(v).n) + push!(offset, max(0, -value(v).n)) push!(default, get(value(v).default)) else - offset[i] = min(offset[i], value(v).n) + offset[i] = max(offset[i], -value(v).n) end v end @@ -56,29 +56,43 @@ function create_steps(v::IVertex, n) [bumpinputs(spliceinputs(v, hiddeninput(i))) for i = 1:n] end +function getvar(n, step, steps, offset, default) + if step < 1 + hiddeninput(sum(offset[1:n-1]) + 1 - step) + elseif step > length(steps) + constant(default[n]) + else + steps[step][1,n] + end +end + +function stateout(steps, offset, default) + outs = [] + defaults = [] + for i = 1:length(offset), j = 1:offset[i] + push!(outs, getvar(i, length(steps)-j+1, steps, offset, default)) + push!(defaults, default[i]) + end + group(outs...), defaults +end + function unrollgraph(v::IVertex, n) state, offset, default = collect_state(v) v = group(group(state...), v) steps = create_steps(v, n) for i = 1:n vars = inputs(steps[i][1]) - prewalk!(steps[i]) do v + postwalk!(steps[i]) do v isa(value(v), Offset) || return v - stepid = value(v).n + i varid = findfirst(vars,v[1]) - if stepid ∈ 1:n - steps[stepid][1,varid] - elseif stepid < 1 - vertex(:input, constant(varid)) - elseif stepid > n - constant(:output, constant(varid)) - end + getvar(varid, value(v).n + i, steps, offset, default) end end - group(steps[end][1],group(map(x->x[2], steps)...)) + state, defaults = stateout(steps, offset, default) + group(state,group(map(x->x[2], steps)...)), map(Flux.state, defaults) end -unrollgraph(atomise(Chain(r,r)), 5) |> detuple |> syntax |> prettify +unrollgraph(m, n) = unrollgraph(atomise(m), n) type Unrolled <: Model model @@ -90,12 +104,3 @@ end graph(u::Unrolled) = u.graph unroll(model, n) = Unrolled(model, unrollgraph(model, n)..., n) - -@net type Recurrent - y - function (x) - y = σ(x, y{-1}) - end -end - -r = Recurrent(rand(5))