diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 74cca057..188b9116 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -1,53 +1,44 @@ -function forward_temporaries(body, Δs) - # exs = union((common(body, Δ) for Δ in values(Δs))...) - # filter!(ex -> !(@capture(value(ex), self._) || isconstant(ex)), exs) - # [ex=>symbol("temp", i) for (i, ex) in enumerate(exs)] - return Dict() -end - function process_func(ex, params) @capture(shortdef(ex), (args__,) -> body_) body = il(graphm(body)) body = map(x -> x in params ? :(self.$x) : x, body) - Δ = invert(body, @flow(Δ)) - return args, body, Δ + return args, body end -function build_type(T, params, temps) +function build_type(T, params) quote type $T $(params...) $([symbol("Δ", s) for s in params]...) - $(temps...) end $T($(params...)) = $T($(params...), - $((:(zeros($p)) for p in params)...), - $((:nothing for t in temps)...)) + $((:(zeros($p)) for p in params)...)) end end -function build_forward(body, temps) - body = cut_forward(body) - forward = IVertex{Any}(Flow.Do()) - for (ex, k) in temps - k = Expr(:quote, k) - thread!(forward, @v(setfield!(:self, k, ex))) - end - thread!(forward, body) - cse(forward) +function build_forward(body, args) + body = cut_forward(body, args) + cse(body) end -function build_backward(Δs, x, params, temps) +function build_backward(body, x, params) + Δs, Δloops = cut_backward(body, [x]) back = IVertex{Any}(Flow.Do()) - tempify(v) = prewalk(v -> haskey(temps, v) ? @v(:(self.$(temps[v]))) : v, v) for param in params haskey(Δs, :(self.$param)) || continue k = symbol("Δ", param) ksym = Expr(:quote, k) - ex = tempify(Δs[:(self.$param)]) + ex = Δs[:(self.$param)] + for Δloop in Δloops + ex = addΔ(ex, get(Δloop, :(self.$param), vertex(0))) + end thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex))) end - thread!(back, @flow(tuple($(tempify(Δs[x]))))) + ex = Δs[x] + for Δloop in Δloops + ex = addΔ(ex, get(Δloop, x, vertex(0))) + end + thread!(back, @flow(tuple($ex))) cse(back) end @@ -65,13 +56,12 @@ function process_type(ex) @destruct [params = true || [], funcs = false || []] = groupby(x->isa(x, Symbol), fs) @assert length(funcs) == 1 - args, body, Δs = process_func(funcs[1], params) + args, body = process_func(funcs[1], params) @assert length(args) == 1 - temps = forward_temporaries(body, Δs) quote - $(build_type(T, params, collect(values(temps)))) - (self::$T)($(args...),) = $(syntax(build_forward(body, temps))) - back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(Δs, args[1], params, temps))) + $(build_type(T, params)) + (self::$T)($(args...),) = $(syntax(build_forward(body, args))) + back!(self::$T, Δ, $(args...)) = $(syntax(build_backward(body, args[1], params))) $(build_update(T, params)) end |> longdef end diff --git a/src/compiler/diff.jl b/src/compiler/diff.jl index e98c8033..dfc4b960 100644 --- a/src/compiler/diff.jl +++ b/src/compiler/diff.jl @@ -2,6 +2,8 @@ import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v vertex(a...) = IVertex{Any}(a...) +addΔ(a, b) = vertex(:+, a, b) + # Special case a couple of operators to clean up output code const symbolic = Dict() @@ -13,7 +15,7 @@ function ∇v(v::Vertex, Δ) map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v)) end -function invert(v::IVertex, Δ, out = d()) +function invert(v::IVertex, Δ = vertex(:Δ), out = d()) @assert !iscyclic(v) if isconstant(v) @assert !haskey(out, value(v)) diff --git a/src/compiler/loop.jl b/src/compiler/loop.jl index 2b664012..91c02deb 100644 --- a/src/compiler/loop.jl +++ b/src/compiler/loop.jl @@ -7,13 +7,53 @@ function delays(v::IVertex) return ds end -function cut_forward(v::IVertex) - pushes = map(x->vertex(:push!, vertex(:(self.delay)), v[1]), delays(v)) +function cut(v::IVertex, f = _ -> il(@flow(last(self.delay)))) + prewalk(v) do v + value(v) == :Delay ? f(v) : v + end +end + +replaceall(d::Dict, args...) = Dict(k => replace(v, args...) for (k, v) in d) + +# Create the forward function; a single delay node becomes an +# input and an output node. +function cut_forward(v::IVertex, params, ds = delays(v)) + pushes = map(x->vertex(:push!, vertex(:(self.delay)), x[1], map(vertex, params)...), ds) isempty(pushes) && return v @assert length(pushes) == 1 v = vertex(Flow.Do(), pushes..., v) - prewalk(v) do v - value(v) == :Delay || return v - il(@flow(pop!(self.delay))) - end + cut(v) end + +# Given a delay node, give the parameter gradients with respect to +# the node and a function which will propagate gradients around +# the loop. +function invertloop(v::IVertex, params) + @gensym input + v = cut(v[1], v -> vertex(input)) + Δs = invert(v, @flow(Δloop)) + Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay)))) + Δs, :((Δ, $input, $(params...)) -> $(syntax(cse(Δs[input])))) +end + +# Returns: +# Parameter gradients with respect to the function +# Parameter gradients with respect to each delay node +function cut_backward(v::IVertex, params, ds = delays(v)) + isempty(ds) && return invert(v), [] + @assert length(ds) == 1 + @gensym input + Δs = invert(cut(v, _ -> vertex(input))) + Δs = replaceall(Δs, vertex(input), il(@flow(last(self.delay)))) + Δloop, ∇loop = invertloop(ds[1], params) + Δh = vertex(:back!, vertex(:(self.delay)), Δs[input], vertex(∇loop)) + Δloop = replaceall(Δloop, vertex(:Δloop), Δh) + Δs, [Δloop] +end + +# g = il(@flow begin +# hidden = σ( Wxh*x + Whh*Delay(hidden) + bh ) +# y = σ( Why*hidden + by ) +# end) + +# cut_backward(g, [:x])[1]