recurrence semi-working

This commit is contained in:
Mike J Innes 2016-06-08 01:57:14 +01:00
parent 5c018bcae7
commit 1480e1fab6
3 changed files with 70 additions and 38 deletions

View File

@ -1,53 +1,44 @@
function forward_temporaries(body, Δs)
# exs = union((common(body, Δ) for Δ in values(Δs))...)
# filter!(ex -> !(@capture(value(ex), self._) || isconstant(ex)), exs)
# [ex=>symbol("temp", i) for (i, ex) in enumerate(exs)]
return Dict()
end
function process_func(ex, params) function process_func(ex, params)
@capture(shortdef(ex), (args__,) -> body_) @capture(shortdef(ex), (args__,) -> body_)
body = il(graphm(body)) body = il(graphm(body))
body = map(x -> x in params ? :(self.$x) : x, body) body = map(x -> x in params ? :(self.$x) : x, body)
Δ = invert(body, @flow(Δ)) return args, body
return args, body, Δ
end end
function build_type(T, params, temps) function build_type(T, params)
quote quote
type $T type $T
$(params...) $(params...)
$([symbol("Δ", s) for s in params]...) $([symbol("Δ", s) for s in params]...)
$(temps...)
end end
$T($(params...)) = $T($(params...), $T($(params...)) = $T($(params...),
$((:(zeros($p)) for p in params)...), $((:(zeros($p)) for p in params)...))
$((:nothing for t in temps)...))
end end
end end
function build_forward(body, temps) function build_forward(body, args)
body = cut_forward(body) body = cut_forward(body, args)
forward = IVertex{Any}(Flow.Do()) cse(body)
for (ex, k) in temps
k = Expr(:quote, k)
thread!(forward, @v(setfield!(:self, k, ex)))
end
thread!(forward, body)
cse(forward)
end end
function build_backward(Δs, x, params, temps) function build_backward(body, x, params)
Δs, Δloops = cut_backward(body, [x])
back = IVertex{Any}(Flow.Do()) back = IVertex{Any}(Flow.Do())
tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v)
for param in params for param in params
haskey(Δs, :(self.$param)) || continue haskey(Δs, :(self.$param)) || continue
k = symbol("Δ", param) k = symbol("Δ", param)
ksym = Expr(:quote, k) ksym = Expr(:quote, k)
ex = tempify(Δs[:(self.$param)]) ex = Δs[:(self.$param)]
for Δloop in Δloops
ex = addΔ(ex, get(Δloop, :(self.$param), vertex(0)))
end
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex))) thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
end end
thread!(back, @flow(tuple($(tempify(Δs[x]))))) ex = Δs[x]
for Δloop in Δloops
ex = addΔ(ex, get(Δloop, x, vertex(0)))
end
thread!(back, @flow(tuple($ex)))
cse(back) cse(back)
end end
@ -65,13 +56,12 @@ function process_type(ex)
@destruct [params = true || [], @destruct [params = true || [],
funcs = false || []] = groupby(x->isa(x, Symbol), fs) funcs = false || []] = groupby(x->isa(x, Symbol), fs)
@assert length(funcs) == 1 @assert length(funcs) == 1
args, body, Δs = process_func(funcs[1], params) args, body = process_func(funcs[1], params)
@assert length(args) == 1 @assert length(args) == 1
temps = forward_temporaries(body, Δs)
quote quote
$(build_type(T, params, collect(values(temps)))) $(build_type(T, params))
(self::$T)($(args...),) = $(syntax(build_forward(body, temps))) (self::$T)($(args...),) = $(syntax(build_forward(body, args)))
back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(Δs, args[1], params, temps))) back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params)))
$(build_update(T, params)) $(build_update(T, params))
end |> longdef end |> longdef
end end

View File

@ -2,6 +2,8 @@ import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
vertex(a...) = IVertex{Any}(a...) vertex(a...) = IVertex{Any}(a...)
addΔ(a, b) = vertex(:+, a, b)
# Special case a couple of operators to clean up output code # Special case a couple of operators to clean up output code
const symbolic = Dict() const symbolic = Dict()
@ -13,7 +15,7 @@ function ∇v(v::Vertex, Δ)
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v)) map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
end end
function invert(v::IVertex, Δ, out = d()) function invert(v::IVertex, Δ = vertex(), out = d())
@assert !iscyclic(v) @assert !iscyclic(v)
if isconstant(v) if isconstant(v)
@assert !haskey(out, value(v)) @assert !haskey(out, value(v))

View File

@ -7,13 +7,53 @@ function delays(v::IVertex)
return ds return ds
end end
function cut_forward(v::IVertex) function cut(v::IVertex, f = _ -> il(@flow(last(self.delay))))
pushes = map(x->vertex(:push!, vertex(:(self.delay)), v[1]), delays(v)) prewalk(v) do v
value(v) == :Delay ? f(v) : v
end
end
replaceall(d::Dict, args...) = Dict(k => replace(v, args...) for (k, v) in d)
# Create the forward function; a single delay node becomes an
# input and an output node.
function cut_forward(v::IVertex, params, ds = delays(v))
pushes = map(x->vertex(:push!, vertex(:(self.delay)), x[1], map(vertex, params)...), ds)
isempty(pushes) && return v isempty(pushes) && return v
@assert length(pushes) == 1 @assert length(pushes) == 1
v = vertex(Flow.Do(), pushes..., v) v = vertex(Flow.Do(), pushes..., v)
prewalk(v) do v cut(v)
value(v) == :Delay || return v
il(@flow(pop!(self.delay)))
end
end end
# Given a delay node, give the parameter gradients with respect to
# the node and a function which will propagate gradients around
# the loop.
function invertloop(v::IVertex, params)
@gensym input
v = cut(v[1], v -> vertex(input))
Δs = invert(v, @flow(Δloop))
Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay))))
Δs, :((Δ, $input, $(params...)) -> $(syntax(cse(Δs[input]))))
end
# Returns:
# Parameter gradients with respect to the function
# Parameter gradients with respect to each delay node
function cut_backward(v::IVertex, params, ds = delays(v))
isempty(ds) && return invert(v), []
@assert length(ds) == 1
@gensym input
Δs = invert(cut(v, _ -> vertex(input)))
Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay))))
Δloop, ∇loop = invertloop(ds[1], params)
Δh = vertex(:back!, vertex(:(self.delay)), Δs[input], vertex(∇loop))
Δloop = replaceall(Δloop, vertex(:Δloop), Δh)
Δs, [Δloop]
end
# g = il(@flow begin
# hidden = σ( Wxh*x + Whh*Delay(hidden) + bh )
# y = σ( Why*hidden + by )
# end)
# cut_backward(g, [:x])[1]