initial work on loops
This commit is contained in:
parent
3b38b8696d
commit
96102f0130
@ -14,6 +14,7 @@ back!(m::Model, ∇) = error("Backprop not implemented for $(typeof(m))")
|
||||
update!(m::Model, η) = m
|
||||
|
||||
include("compiler/diff.jl")
|
||||
include("compiler/loop.jl")
|
||||
include("compiler/code.jl")
|
||||
|
||||
include("cost.jl")
|
||||
|
@ -27,6 +27,7 @@ function build_type(T, params, temps)
|
||||
end
|
||||
|
||||
function build_forward(body, temps)
|
||||
body = cut_forward(body)
|
||||
forward = IVertex{Any}(Flow.Do())
|
||||
for (ex, k) in temps
|
||||
k = Expr(:quote, k)
|
||||
@ -75,8 +76,19 @@ function process_type(ex)
|
||||
end |> longdef |> MacroTools.flatten
|
||||
end
|
||||
|
||||
process_type(:(type Sigmoid
|
||||
W
|
||||
b
|
||||
x -> σ(W*x+b)
|
||||
# process_type(:(type Sigmoid
|
||||
# W
|
||||
# b
|
||||
# bp
|
||||
# x -> σ(W*x+b)
|
||||
# end)) |> prettify
|
||||
|
||||
process_type(:(type Recurrent
|
||||
Wxh; Whh; Bh
|
||||
Why; By
|
||||
|
||||
function (x)
|
||||
hidden = σ( Wxh*x + Whh*Delay(hidden) + Bh )
|
||||
y = σ( Why*hidden + By )
|
||||
end
|
||||
end)) |> prettify
|
||||
|
16
src/compiler/loop.jl
Normal file
16
src/compiler/loop.jl
Normal file
@ -0,0 +1,16 @@
|
||||
function cut_forward(v::IVertex)
|
||||
pushes = []
|
||||
Flow.prefor(v) do w
|
||||
value(w) == :Delay &&
|
||||
push!(pushes, vertex(:push!, vertex(:(self.delay)), w[1]))
|
||||
end
|
||||
isempty(pushes) && return v
|
||||
@assert length(pushes) == 1
|
||||
v = vertex(Flow.Do(), pushes..., v)
|
||||
prewalk(v) do v
|
||||
value(v) == :Delay || return v
|
||||
il(@flow(pop!(self.delay)))
|
||||
end
|
||||
end
|
||||
|
||||
cut_forward(g)
|
Loading…
Reference in New Issue
Block a user