recurrence semi-working
This commit is contained in:
parent
5c018bcae7
commit
1480e1fab6
@ -1,53 +1,44 @@
|
|||||||
function forward_temporaries(body, Δs)
|
|
||||||
# exs = union((common(body, Δ) for Δ in values(Δs))...)
|
|
||||||
# filter!(ex -> !(@capture(value(ex), self._) || isconstant(ex)), exs)
|
|
||||||
# [ex=>symbol("temp", i) for (i, ex) in enumerate(exs)]
|
|
||||||
return Dict()
|
|
||||||
end
|
|
||||||
|
|
||||||
function process_func(ex, params)
|
function process_func(ex, params)
|
||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
body = il(graphm(body))
|
body = il(graphm(body))
|
||||||
body = map(x -> x in params ? :(self.$x) : x, body)
|
body = map(x -> x in params ? :(self.$x) : x, body)
|
||||||
Δ = invert(body, @flow(Δ))
|
return args, body
|
||||||
return args, body, Δ
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_type(T, params, temps)
|
function build_type(T, params)
|
||||||
quote
|
quote
|
||||||
type $T
|
type $T
|
||||||
$(params...)
|
$(params...)
|
||||||
$([symbol("Δ", s) for s in params]...)
|
$([symbol("Δ", s) for s in params]...)
|
||||||
$(temps...)
|
|
||||||
end
|
end
|
||||||
$T($(params...)) = $T($(params...),
|
$T($(params...)) = $T($(params...),
|
||||||
$((:(zeros($p)) for p in params)...),
|
$((:(zeros($p)) for p in params)...))
|
||||||
$((:nothing for t in temps)...))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_forward(body, temps)
|
function build_forward(body, args)
|
||||||
body = cut_forward(body)
|
body = cut_forward(body, args)
|
||||||
forward = IVertex{Any}(Flow.Do())
|
cse(body)
|
||||||
for (ex, k) in temps
|
|
||||||
k = Expr(:quote, k)
|
|
||||||
thread!(forward, @v(setfield!(:self, k, ex)))
|
|
||||||
end
|
|
||||||
thread!(forward, body)
|
|
||||||
cse(forward)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_backward(Δs, x, params, temps)
|
function build_backward(body, x, params)
|
||||||
|
Δs, Δloops = cut_backward(body, [x])
|
||||||
back = IVertex{Any}(Flow.Do())
|
back = IVertex{Any}(Flow.Do())
|
||||||
tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v)
|
|
||||||
for param in params
|
for param in params
|
||||||
haskey(Δs, :(self.$param)) || continue
|
haskey(Δs, :(self.$param)) || continue
|
||||||
k = symbol("Δ", param)
|
k = symbol("Δ", param)
|
||||||
ksym = Expr(:quote, k)
|
ksym = Expr(:quote, k)
|
||||||
ex = tempify(Δs[:(self.$param)])
|
ex = Δs[:(self.$param)]
|
||||||
|
for Δloop in Δloops
|
||||||
|
ex = addΔ(ex, get(Δloop, :(self.$param), vertex(0)))
|
||||||
|
end
|
||||||
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
|
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
|
||||||
end
|
end
|
||||||
thread!(back, @flow(tuple($(tempify(Δs[x])))))
|
ex = Δs[x]
|
||||||
|
for Δloop in Δloops
|
||||||
|
ex = addΔ(ex, get(Δloop, x, vertex(0)))
|
||||||
|
end
|
||||||
|
thread!(back, @flow(tuple($ex)))
|
||||||
cse(back)
|
cse(back)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -65,13 +56,12 @@ function process_type(ex)
|
|||||||
@destruct [params = true || [],
|
@destruct [params = true || [],
|
||||||
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
|
funcs = false || []] = groupby(x->isa(x, Symbol), fs)
|
||||||
@assert length(funcs) == 1
|
@assert length(funcs) == 1
|
||||||
args, body, Δs = process_func(funcs[1], params)
|
args, body = process_func(funcs[1], params)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
temps = forward_temporaries(body, Δs)
|
|
||||||
quote
|
quote
|
||||||
$(build_type(T, params, collect(values(temps))))
|
$(build_type(T, params))
|
||||||
(self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
|
(self::$T)($(args...),) = $(syntax(build_forward(body, args)))
|
||||||
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(Δs, args[1], params, temps)))
|
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
|
||||||
$(build_update(T, params))
|
$(build_update(T, params))
|
||||||
end |> longdef
|
end |> longdef
|
||||||
end
|
end
|
||||||
|
@ -2,6 +2,8 @@ import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
|
|||||||
|
|
||||||
vertex(a...) = IVertex{Any}(a...)
|
vertex(a...) = IVertex{Any}(a...)
|
||||||
|
|
||||||
|
addΔ(a, b) = vertex(:+, a, b)
|
||||||
|
|
||||||
# Special case a couple of operators to clean up output code
|
# Special case a couple of operators to clean up output code
|
||||||
const symbolic = Dict()
|
const symbolic = Dict()
|
||||||
|
|
||||||
@ -13,7 +15,7 @@ function ∇v(v::Vertex, Δ)
|
|||||||
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
|
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
|
||||||
end
|
end
|
||||||
|
|
||||||
function invert(v::IVertex, Δ, out = d())
|
function invert(v::IVertex, Δ = vertex(:Δ), out = d())
|
||||||
@assert !iscyclic(v)
|
@assert !iscyclic(v)
|
||||||
if isconstant(v)
|
if isconstant(v)
|
||||||
@assert !haskey(out, value(v))
|
@assert !haskey(out, value(v))
|
||||||
|
@ -7,13 +7,53 @@ function delays(v::IVertex)
|
|||||||
return ds
|
return ds
|
||||||
end
|
end
|
||||||
|
|
||||||
function cut_forward(v::IVertex)
|
function cut(v::IVertex, f = _ -> il(@flow(last(self.delay))))
|
||||||
pushes = map(x->vertex(:push!, vertex(:(self.delay)), v[1]), delays(v))
|
prewalk(v) do v
|
||||||
|
value(v) == :Delay ? f(v) : v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
replaceall(d::Dict, args...) = Dict(k => replace(v, args...) for (k, v) in d)
|
||||||
|
|
||||||
|
# Create the forward function; a single delay node becomes an
|
||||||
|
# input and an output node.
|
||||||
|
function cut_forward(v::IVertex, params, ds = delays(v))
|
||||||
|
pushes = map(x->vertex(:push!, vertex(:(self.delay)), x[1], map(vertex, params)...), ds)
|
||||||
isempty(pushes) && return v
|
isempty(pushes) && return v
|
||||||
@assert length(pushes) == 1
|
@assert length(pushes) == 1
|
||||||
v = vertex(Flow.Do(), pushes..., v)
|
v = vertex(Flow.Do(), pushes..., v)
|
||||||
prewalk(v) do v
|
cut(v)
|
||||||
value(v) == :Delay || return v
|
|
||||||
il(@flow(pop!(self.delay)))
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Given a delay node, give the parameter gradients with respect to
|
||||||
|
# the node and a function which will propagate gradients around
|
||||||
|
# the loop.
|
||||||
|
function invertloop(v::IVertex, params)
|
||||||
|
@gensym input
|
||||||
|
v = cut(v[1], v -> vertex(input))
|
||||||
|
Δs = invert(v, @flow(Δloop))
|
||||||
|
Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay))))
|
||||||
|
Δs, :((Δ, $input, $(params...)) -> $(syntax(cse(Δs[input]))))
|
||||||
|
end
|
||||||
|
|
||||||
|
# Returns:
|
||||||
|
# Parameter gradients with respect to the function
|
||||||
|
# Parameter gradients with respect to each delay node
|
||||||
|
function cut_backward(v::IVertex, params, ds = delays(v))
|
||||||
|
isempty(ds) && return invert(v), []
|
||||||
|
@assert length(ds) == 1
|
||||||
|
@gensym input
|
||||||
|
Δs = invert(cut(v, _ -> vertex(input)))
|
||||||
|
Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay))))
|
||||||
|
Δloop, ∇loop = invertloop(ds[1], params)
|
||||||
|
Δh = vertex(:back!, vertex(:(self.delay)), Δs[input], vertex(∇loop))
|
||||||
|
Δloop = replaceall(Δloop, vertex(:Δloop), Δh)
|
||||||
|
Δs, [Δloop]
|
||||||
|
end
|
||||||
|
|
||||||
|
# g = il(@flow begin
|
||||||
|
# hidden = σ( Wxh*x + Whh*Delay(hidden) + bh )
|
||||||
|
# y = σ( Why*hidden + by )
|
||||||
|
# end)
|
||||||
|
|
||||||
|
# cut_backward(g, [:x])[1]
|
||||||
|
Loading…
Reference in New Issue
Block a user