initial work on loops

This commit is contained in:
Mike J Innes 2016-06-07 16:01:48 +01:00
parent 3b38b8696d
commit 96102f0130
3 changed files with 33 additions and 4 deletions

View File

@ -14,6 +14,7 @@ back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
update!(m::Model, η) = m update!(m::Model, η) = m
include("compiler/diff.jl") include("compiler/diff.jl")
include("compiler/loop.jl")
include("compiler/code.jl") include("compiler/code.jl")
include("cost.jl") include("cost.jl")

View File

@ -27,6 +27,7 @@ function build_type(T, params, temps)
end end
function build_forward(body, temps) function build_forward(body, temps)
body = cut_forward(body)
forward = IVertex{Any}(Flow.Do()) forward = IVertex{Any}(Flow.Do())
for (ex, k) in temps for (ex, k) in temps
k = Expr(:quote, k) k = Expr(:quote, k)
@ -75,8 +76,19 @@ function process_type(ex)
end |> longdef |> MacroTools.flatten end |> longdef |> MacroTools.flatten
end end
process_type(:(type Sigmoid # process_type(:(type Sigmoid
W # W
b # b
x -> σ(W*x+b) # bp
# x -> σ(W*x+b)
# end)) |> prettify
process_type(:(type Recurrent
Wxh; Whh; Bh
Why; By
function (x)
hidden = σ( Wxh*x + Whh*Delay(hidden) + Bh )
y = σ( Why*hidden + By )
end
end)) |> prettify end)) |> prettify

16
src/compiler/loop.jl Normal file
View File

@ -0,0 +1,16 @@
function cut_forward(v::IVertex)
pushes = []
Flow.prefor(v) do w
value(w) == :Delay &&
push!(pushes, vertex(:push!, vertex(:(self.delay)), w[1]))
end
isempty(pushes) && return v
@assert length(pushes) == 1
v = vertex(Flow.Do(), pushes..., v)
prewalk(v) do v
value(v) == :Delay || return v
il(@flow(pop!(self.delay)))
end
end
cut_forward(g)