Flux.jl/src/compiler/loops.jl
2016-10-25 21:10:04 +01:00

87 lines
2.0 KiB
Julia
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Marker placed in a graph for a delayed (recurrent) edge: it stands for the
# previous-step value of the variable `name` (see `liftloops!`).
# `default`, when non-null, is collected by `break!` and used by `unroll` as the
# initial state for the first step.
type Delay
  name::Symbol
  default::Nullable{Param}
end

# A `Delay` with no default value (an empty `Nullable{Param}`).
Delay(name) = Delay(name, Nullable{Param}())
# Rewrite a block of `lhs = rhs` statements so that reads of recurrent state go
# through a `Delay` node.
#
# `ex`     - block expression; `Flow.normedges` presumably normalises every
#            statement into `lhs = rhs` form (the loops below rely on that
#            shape) - TODO confirm against DataFlow.
# `params` - collection of parameter names.
#
# For each name `h` that is both in `params` and assigned by some statement of
# the block, a fresh `gensym("edge")` symbol is created. Every occurrence of `h`
# on a right-hand side is replaced by that symbol. A statement
# `edge = Delay(h)(h)` is then prepended to the block.
#
# Returns the rewritten block. Callers should use the return value: if
# `normedges` returns a copy, the caller's `ex` is not the rewritten object.
function liftloops!(ex, params)
  # Work on the normalised block. Previously this result was bound to an
  # unused local and the un-normalised `ex` was indexed below instead.
  ex = Flow.normedges(ex)
  hidden = intersect((b.args[1] for b in ex.args), params)
  edges = Dict(h => gensym("edge") for h in hidden)
  for b in ex.args
    b.args[2] = MacroTools.postwalk(x -> get(edges, x, x), b.args[2])
  end
  # Prepended after the rewrite above, so `$h` here still names the raw state.
  for (h, edge) in edges
    unshift!(ex.args, :($edge = $(Delay(h))($h)))
  end
  return ex
end
# Return `true` if `model` contains a recurrent loop.
#
# That is the case when `model`'s own graph is cyclic, or when `hasloops` holds
# (recursively) for any element that `map` visits in that graph. A model whose
# `graph` is `nothing` has no loops.
function hasloops(model)
  g = graph(model)
  g === nothing && return false
  iscyclic(g) && return true
  # `map` is used only to visit the graph's elements. A `Ref` records a hit
  # without reassigning a captured local, which would force it into a Box.
  found = Ref(false)
  map(m -> hasloops(m) && (found[] = true), g)
  return found[]
end
# Flatten `model`'s graph so that no node wraps a sub-model with loops.
#
# Each node whose value has loops (per `hasloops`) is replaced by that value's
# own graph, atomised recursively and spliced onto the node's inputs. All other
# nodes are returned unchanged.
function atomise(model)
  return postwalk(graph(model)) do node
    inner = value(node)
    if hasloops(inner)
      return spliceinputs(atomise(inner), inputs(node)...)
    else
      return node
    end
  end
end
# Graph node computing `getindex(ModelInput(1), n)`, i.e. element `n` of the
# first model input. `break!` and `unroll!` use it in place of the n-th delayed
# (hidden-state) value.
function hinput(n)
  firstinput = constant(ModelInput(1))
  return vertex(getindex, firstinput, constant(n))
end
# Cut the cycle through the `Delay` vertex `delay`.
#
# `prewalk!` walks the subgraph feeding into the delay (`delay[1]`) and
# substitutes `hinput(n)` for every vertex identical (`===`) to `delay`.
# Returns the result of that `prewalk!`.
function unroll!(delay::IVertex, n)
  cut(node) = node === delay ? hinput(n) : node
  return prewalk!(cut, delay[1])
end
# Turn a graph containing `Delay` vertices into an acyclic single-step graph.
#
# Inputs are first passed through `bumpinputs`. Presumably this shifts existing
# inputs so that input 1 can carry the state tuple read by `hinput` - confirm.
# Each `Delay` vertex found by `prewalk!` gets the next index k = 1, 2, ...:
#   - its feeding subgraph is cut by `unroll!(v, k)` and kept as the k-th
#     next-state expression;
#   - its default value is extracted with `get`, which throws if the
#     `Nullable` is empty;
#   - the vertex itself is replaced by `hinput(k)`.
#
# Returns `(step, defaults)`. `step` is
# `cse(tuple(tuple(next_states...), output))`, where `cse` is presumably
# common-subexpression elimination. `defaults` lists the delays' default values
# in the same order.
function break!(g::IVertex)
  g = bumpinputs(g)
  nextstates = Any[]
  inits = Any[]
  g = prewalk!(g) do node
    if !isa(value(node), Delay)
      return node
    end
    k = length(nextstates) + 1
    push!(nextstates, unroll!(node, k))
    push!(inits, get(value(node).default))
    return hinput(k)
  end
  return cse(vertex(tuple, vertex(tuple, nextstates...), g)), inits
end
# Unroll the recurrent `model` over `n` time steps.
#
# The model is atomised and broken into one acyclic step graph (`break!`).
# Step 1 receives the delays' default values as its state and `ModelInput(1)`
# as its input. Step i receives the state produced by step i-1 and
# `ModelInput(i)`.
#
# Returns a vertex computing `(final_state, (output_1, ..., output_n))`.
# Throws `ArgumentError` when `n < 1`. Previously such `n` silently produced a
# one-step unroll.
function unroll(model, n)
  n >= 1 || throw(ArgumentError("unroll needs at least one step, got n = $n"))
  # Named `step` rather than `graph` to avoid shadowing the `graph` function.
  step, defaults = break!(atomise(model))
  initstate = vertex(tuple, map(constant, defaults)...)
  # Each element is a `(state, output)` tuple vertex for one time step.
  steps = [spliceinputs(step, initstate, constant(ModelInput(1)))]
  for i = 2:n
    push!(steps, spliceinputs(step, steps[end][1], constant(ModelInput(i))))
  end
  state = steps[end][1]
  outputs = map(s -> s[2], steps)
  return vertex(tuple, state, vertex(tuple, outputs...))
end
# r = Chain(Recurrent(30,10), Recurrent(10, 20))
# unroll(r, 1) |> syntax |> prettify |> display
# A simple recurrent layer with input weights `Wx`, recurrent weights `Wh`,
# bias `B` and state `hidden`.
# Each call computes `hidden = σ(Wx*x + Wh*hidden + B)`. The assignment is the
# last expression, so the new `hidden` is also the value returned.
# NOTE(review): `hidden` is both a field and an assignment target in the forward
# pass. That appears to be what lets the `@net` machinery (cf. `liftloops!`)
# treat it as state carried between steps - confirm.
@net type Recurrent
Wx; Wh; B
hidden
function (x)
hidden = σ( Wx*x + Wh*hidden + B )
end
end
# Build a `Recurrent` layer from `in` inputs to `out` hidden units.
# `Wx` is `init(out, in)`, `Wh` is `init(out, out)`, `B` is `init(out)`, and the
# hidden state starts as `zeros(out)`. The keyword `init` defaults to `initn`.
function Recurrent(in::Integer, out::Integer; init = initn)
  Wx = init(out, in)
  Wh = init(out, out)
  B = init(out)
  h0 = zeros(out)
  return Recurrent(Wx, Wh, B, h0)
end
# Compact one-line display for `Recurrent` layers; the parameters are not printed.
function Base.show(io::IO, r::Recurrent)
  print(io, "Flux.Recurrent(...)")
end