diff --git a/src/Flux.jl b/src/Flux.jl index bc2d82a6..610f13db 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -14,6 +14,7 @@ back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))") update!(m::Model, η) = m include("compiler/diff.jl") +include("compiler/loop.jl") include("compiler/code.jl") include("cost.jl") diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 0f517815..66a9aba5 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -27,6 +27,7 @@ function build_type(T, params, temps) end function build_forward(body, temps) + body = cut_forward(body) forward = IVertex{Any}(Flow.Do()) for (ex, k) in temps k = Expr(:quote, k) @@ -75,8 +76,19 @@ function process_type(ex) end |> longdef |> MacroTools.flatten end -process_type(:(type Sigmoid - W - b - x -> σ(W*x+b) +# process_type(:(type Sigmoid +# W +# b +# bp +# x -> σ(W*x+b) +# end)) |> prettify + +process_type(:(type Recurrent + Wxh; Whh; Bh + Why; By + + function (x) + hidden = σ( Wxh*x + Whh*Delay(hidden) + Bh ) + y = σ( Why*hidden + By ) + end end)) |> prettify diff --git a/src/compiler/loop.jl b/src/compiler/loop.jl new file mode 100644 index 00000000..2109f0d0 --- /dev/null +++ b/src/compiler/loop.jl @@ -0,0 +1,16 @@ +function cut_forward(v::IVertex) + pushes = [] + Flow.prefor(v) do w + value(w) == :Delay && + push!(pushes, vertex(:push!, vertex(:(self.delay)), w[1])) + end + isempty(pushes) && return v + @assert length(pushes) == 1 + v = vertex(Flow.Do(), pushes..., v) + prewalk(v) do v + value(v) == :Delay || return v + il(@flow(pop!(self.delay))) + end +end + +cut_forward(g)