remove loop compilation
This commit is contained in:
parent
e5856d8b27
commit
3de0bc4dec
|
@ -15,7 +15,6 @@ update!(m::Model, η) = m
|
|||
include("capacitor.jl")
|
||||
|
||||
include("compiler/diff.jl")
|
||||
include("compiler/loop.jl")
|
||||
include("compiler/code.jl")
|
||||
|
||||
include("cost.jl")
|
||||
|
|
|
@ -17,27 +17,20 @@ function build_type(T, params)
|
|||
end
|
||||
|
||||
function build_forward(body, args)
|
||||
body = cut_forward(body, args)
|
||||
cse(body)
|
||||
end
|
||||
|
||||
function build_backward(body, x, params)
|
||||
Δs, Δloops = cut_backward(body, [x])
|
||||
Δs = invert(body)
|
||||
back = IVertex{Any}(Flow.Do())
|
||||
for param in params
|
||||
haskey(Δs, :(self.$param)) || continue
|
||||
k = symbol("Δ", param)
|
||||
ksym = Expr(:quote, k)
|
||||
ex = Δs[:(self.$param)]
|
||||
for Δloop in Δloops
|
||||
ex = addΔ(ex, get(Δloop, :(self.$param), vertex(0)))
|
||||
end
|
||||
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
|
||||
end
|
||||
ex = Δs[x]
|
||||
for Δloop in Δloops
|
||||
ex = addΔ(ex, get(Δloop, x, vertex(0)))
|
||||
end
|
||||
thread!(back, @flow(tuple($ex)))
|
||||
cse(back)
|
||||
end
|
||||
|
@ -69,16 +62,5 @@ end
|
|||
# process_type(:(type Sigmoid
|
||||
# W
|
||||
# b
|
||||
# bp
|
||||
# x -> σ(W*x+b)
|
||||
# end)) |> prettify
|
||||
|
||||
process_type(:(type Recurrent
|
||||
Wxh; Whh; Bh
|
||||
Why; By
|
||||
|
||||
function (x)
|
||||
hidden = σ( Wxh*x + Whh*Delay(hidden) + Bh )
|
||||
y = σ( Why*hidden + By )
|
||||
end
|
||||
end)) |> prettify
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
function delays(v::IVertex)
|
||||
ds = []
|
||||
Flow.prefor(v) do w
|
||||
value(w) == :Delay &&
|
||||
push!(ds, w)
|
||||
end
|
||||
return ds
|
||||
end
|
||||
|
||||
function cut(v::IVertex, f = _ -> il(@flow(last(self.delay))))
|
||||
prewalk(v) do v
|
||||
value(v) == :Delay ? f(v) : v
|
||||
end
|
||||
end
|
||||
|
||||
replaceall(d::Dict, args...) = Dict(k => replace(v, args...) for (k, v) in d)
|
||||
|
||||
# Create the forward function; a single delay node becomes an
|
||||
# input and an output node.
|
||||
function cut_forward(v::IVertex, params, ds = delays(v))
|
||||
pushes = map(x->vertex(:push!, vertex(:(self.delay)), x[1], map(vertex, params)...), ds)
|
||||
isempty(pushes) && return v
|
||||
@assert length(pushes) == 1
|
||||
v = vertex(Flow.Do(), pushes..., v)
|
||||
cut(v)
|
||||
end
|
||||
|
||||
# Given a delay node, give the parameter gradients with respect to
|
||||
# the node and a function which will propagate gradients around
|
||||
# the loop.
|
||||
function invertloop(v::IVertex, params)
|
||||
@gensym input
|
||||
v = cut(v[1], v -> vertex(input))
|
||||
Δs = invert(v, @flow(Δloop))
|
||||
Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay))))
|
||||
Δs, :((Δ, $input, $(params...)) -> $(syntax(cse(Δs[input]))))
|
||||
end
|
||||
|
||||
# Returns:
|
||||
# Parameter gradients with respect to the function
|
||||
# Parameter gradients with respect to each delay node
|
||||
function cut_backward(v::IVertex, params, ds = delays(v))
|
||||
isempty(ds) && return invert(v), []
|
||||
@assert length(ds) == 1
|
||||
@gensym input
|
||||
Δs = invert(cut(v, _ -> vertex(input)))
|
||||
Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay))))
|
||||
Δloop, ∇loop = invertloop(ds[1], params)
|
||||
Δh = vertex(:back!, vertex(:(self.delay)), Δs[input], vertex(∇loop))
|
||||
Δloop = replaceall(Δloop, vertex(:Δloop), Δh)
|
||||
Δs, [Δloop]
|
||||
end
|
||||
|
||||
# g = il(@flow begin
|
||||
# hidden = σ( Wxh*x + Whh*Delay(hidden) + bh )
|
||||
# y = σ( Why*hidden + by )
|
||||
# end)
|
||||
|
||||
# cut_backward(g, [:x])[1]
|
Loading…
Reference in New Issue