recurrence semi-working
This commit is contained in:
parent
5c018bcae7
commit
1480e1fab6
@ -1,53 +1,44 @@
|
||||
function forward_temporaries(body, Δs)
|
||||
# exs = union((common(body, Δ) for Δ in values(Δs))...)
|
||||
# filter!(ex -> !(@capture(value(ex), self._) || isconstant(ex)), exs)
|
||||
# [ex=>symbol("temp", i) for (i, ex) in enumerate(exs)]
|
||||
return Dict()
|
||||
end
|
||||
|
||||
function process_func(ex, params)
|
||||
@capture(shortdef(ex), (args__,) -> body_)
|
||||
body = il(graphm(body))
|
||||
body = map(x -> x in params ? :(self.$x) : x, body)
|
||||
Δ = invert(body, @flow(Δ))
|
||||
return args, body, Δ
|
||||
return args, body
|
||||
end
|
||||
|
||||
function build_type(T, params, temps)
|
||||
function build_type(T, params)
|
||||
quote
|
||||
type $T
|
||||
$(params...)
|
||||
$([symbol("Δ", s) for s in params]...)
|
||||
$(temps...)
|
||||
end
|
||||
$T($(params...)) = $T($(params...),
|
||||
$((:(zeros($p)) for p in params)...),
|
||||
$((:nothing for t in temps)...))
|
||||
$((:(zeros($p)) for p in params)...))
|
||||
end
|
||||
end
|
||||
|
||||
function build_forward(body, temps)
|
||||
body = cut_forward(body)
|
||||
forward = IVertex{Any}(Flow.Do())
|
||||
for (ex, k) in temps
|
||||
k = Expr(:quote, k)
|
||||
thread!(forward, @v(setfield!(:self, k, ex)))
|
||||
end
|
||||
thread!(forward, body)
|
||||
cse(forward)
|
||||
function build_forward(body, args)
|
||||
body = cut_forward(body, args)
|
||||
cse(body)
|
||||
end
|
||||
|
||||
function build_backward(Δs, x, params, temps)
|
||||
function build_backward(body, x, params)
|
||||
Δs, Δloops = cut_backward(body, [x])
|
||||
back = IVertex{Any}(Flow.Do())
|
||||
tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v)
|
||||
for param in params
|
||||
haskey(Δs, :(self.$param)) || continue
|
||||
k = symbol("Δ", param)
|
||||
ksym = Expr(:quote, k)
|
||||
ex = tempify(Δs[:(self.$param)])
|
||||
ex = Δs[:(self.$param)]
|
||||
for Δloop in Δloops
|
||||
ex = addΔ(ex, get(Δloop, :(self.$param), vertex(0)))
|
||||
end
|
||||
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
|
||||
end
|
||||
thread!(back, @flow(tuple($(tempify(Δs[x])))))
|
||||
ex = Δs[x]
|
||||
for Δloop in Δloops
|
||||
ex = addΔ(ex, get(Δloop, x, vertex(0)))
|
||||
end
|
||||
thread!(back, @flow(tuple($ex)))
|
||||
cse(back)
|
||||
end
|
||||
|
||||
@ -65,13 +56,12 @@ function process_type(ex)
|
||||
@destruct [params = true || [],
|
||||
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
|
||||
@assert length(funcs) == 1
|
||||
args, body, Δs = process_func(funcs[1], params)
|
||||
args, body = process_func(funcs[1], params)
|
||||
@assert length(args) == 1
|
||||
temps = forward_temporaries(body, Δs)
|
||||
quote
|
||||
$(build_type(T, params, collect(values(temps))))
|
||||
(self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
|
||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(Δs, args[1], params, temps)))
|
||||
$(build_type(T, params))
|
||||
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
|
||||
$(build_update(T, params))
|
||||
end |> longdef
|
||||
end
|
||||
|
@ -2,6 +2,8 @@ import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
|
||||
|
||||
# Convenience constructor: forwards every argument unchanged to
# `IVertex{Any}`.
function vertex(a...)
    return IVertex{Any}(a...)
end
|
||||
|
||||
# Combine two gradient terms by building the vertex `vertex(:+, a, b)`.
# Nothing is evaluated here; the result is a graph node (presumably a
# symbolic sum of `a` and `b` — confirm against the IVertex constructor).
function addΔ(a, b)
    return vertex(:+, a, b)
end
|
||||
|
||||
# Special case a couple of operators to clean up output code
|
||||
const symbolic = Dict()
|
||||
|
||||
@ -13,7 +15,7 @@ function ∇v(v::Vertex, Δ)
|
||||
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
|
||||
end
|
||||
|
||||
function invert(v::IVertex, Δ, out = d())
|
||||
function invert(v::IVertex, Δ = vertex(:Δ), out = d())
|
||||
@assert !iscyclic(v)
|
||||
if isconstant(v)
|
||||
@assert !haskey(out, value(v))
|
||||
|
@ -7,13 +7,53 @@ function delays(v::IVertex)
|
||||
return ds
|
||||
end
|
||||
|
||||
function cut_forward(v::IVertex)
|
||||
pushes = map(x->vertex(:push!, vertex(:(self.delay)), v[1]), delays(v))
|
||||
# Walk the graph `v` with `prewalk` and substitute every vertex whose
# value is `:Delay` with `f(vertex)`; all other vertices are returned
# unchanged. By default a Delay vertex is replaced by the graph for
# `last(self.delay)`.
function cut(v::IVertex, f = _ -> il(@flow(last(self.delay))))
    return prewalk(v) do node
        if value(node) == :Delay
            f(node)
        else
            node
        end
    end
end
|
||||
|
||||
# Build a new Dict with the same keys as `d`, where each value has had
# `replace(value, args...)` applied to it. `d` itself is not modified.
function replaceall(d::Dict, args...)
    return Dict(key => replace(val, args...) for (key, val) in d)
end
|
||||
|
||||
# Create the forward function; a single delay node becomes an
|
||||
# input and an output node.
|
||||
function cut_forward(v::IVertex, params, ds = delays(v))
|
||||
pushes = map(x->vertex(:push!, vertex(:(self.delay)), x[1], map(vertex, params)...), ds)
|
||||
isempty(pushes) && return v
|
||||
@assert length(pushes) == 1
|
||||
v = vertex(Flow.Do(), pushes..., v)
|
||||
prewalk(v) do v
|
||||
value(v) == :Delay || return v
|
||||
il(@flow(pop!(self.delay)))
|
||||
end
|
||||
cut(v)
|
||||
end
|
||||
|
||||
# Given a delay node, give the parameter gradients with respect to
|
||||
# the node and a function which will propagate gradients around
|
||||
# the loop.
|
||||
function invertloop(v::IVertex, params)
    # Fresh placeholder symbol that stands in for the delayed value while
    # the loop body is differentiated.
    @gensym input
    # Only the first input of the delay node is used; any Delay vertex
    # inside it is swapped for the placeholder.
    loopbody = cut(v[1], _ -> vertex(input))
    # Invert the loop body, using the symbol `Δloop` as the incoming
    # gradient.
    grads = invert(loopbody, @flow(Δloop))
    # Put `last(self.delay)` back wherever the placeholder appears in the
    # resulting gradient expressions.
    grads = replaceall(grads, vertex(input), il(@flow(last(self.delay))))
    # Quoted anonymous function whose body is the gradient stored under
    # the placeholder key.
    loopfn = :((Δ, $input, $(params...)) -> $(syntax(cse(grads[input]))))
    return grads, loopfn
end
|
||||
|
||||
# Returns:
|
||||
# Parameter gradients with respect to the function
|
||||
# Parameter gradients with respect to each delay node
|
||||
function cut_backward(v::IVertex, params, ds = delays(v))
    # No delay nodes: plain inversion and no loop gradients.
    if isempty(ds)
        return invert(v), []
    end
    # Only a single delay node is handled.
    @assert length(ds) == 1
    # Placeholder symbol standing in for the delayed value during
    # inversion.
    @gensym input
    grads = invert(cut(v, _ -> vertex(input)))
    grads = replaceall(grads, vertex(input), il(@flow(last(self.delay))))
    # Gradients around the loop, plus the quoted function expression that
    # invertloop builds for the loop.
    loopgrads, loopfn = invertloop(ds[1], params)
    # `back!` call vertex on `self.delay`, fed with the gradient stored
    # under the placeholder key and the loop function.
    hidden = vertex(:back!, vertex(:(self.delay)), grads[input], vertex(loopfn))
    # The loop gradients were built against the symbol `Δloop`; substitute
    # the `back!` vertex for it.
    loopgrads = replaceall(loopgrads, vertex(:Δloop), hidden)
    return grads, [loopgrads]
end
|
||||
|
||||
# g = il(@flow begin
|
||||
# hidden = σ( Wxh*x + Whh*Delay(hidden) + bh )
|
||||
# y = σ( Why*hidden + by )
|
||||
# end)
|
||||
|
||||
# cut_backward(g, [:x])[1]
|
||||
|
Loading…
Reference in New Issue
Block a user