diff --git a/src/rt/code.jl b/src/rt/code.jl index feffa258..28acab2d 100644 --- a/src/rt/code.jl +++ b/src/rt/code.jl @@ -1,22 +1,14 @@ function forward_temporaries(body, ∇s) exs = union((common(body, ∇) for ∇ in values(∇s))...) - filter!(ex -> !@capture(value(ex), self._), exs) + filter!(ex -> !(@capture(value(ex), self._) || isconstant(ex)), exs) [ex=>symbol("temp", i) for (i, ex) in enumerate(exs)] end -resolve_calls(ex) = ex - -function resolve_calls(ex::Expr) - @capture(ex, f_(a__)) ? - Expr(:call, eval(current_module(), f), map(resolve_calls, a)...) : - Expr(ex.head, map(resolve_calls, ex.args)) -end - function process_func(ex, params) @capture(shortdef(ex), (args__,) -> body_) - body = il(graphm(resolve_calls(body))) + body = il(graphm(body)) body = map(x -> x in params ? :(self.$x) : x, body) - ∇ = ∇graph(body, @flow(∇)) + ∇ = invert(body, @flow(∇)) return args, body, ∇ end @@ -74,11 +66,10 @@ function process_type(ex) args, body, ∇s = process_func(funcs[1], params) @assert length(args) == 1 temps = forward_temporaries(body, ∇s) - ∇s quote $(build_type(T, params, collect(values(temps)))) (self::$T)($(args...),) = $(syntax(build_forward(body, temps))) - back!(self::$T, ∇) = $(syntax(build_backward(∇s, args[1], params, temps))) + back!(self::$T, ∇, $(args...)) = $(syntax(build_backward(∇s, args[1], params, temps))) $(build_update(T, params)) end |> longdef |> MacroTools.flatten end diff --git a/src/rt/diff.jl b/src/rt/diff.jl index 88eafd70..4cecd9c3 100644 --- a/src/rt/diff.jl +++ b/src/rt/diff.jl @@ -2,29 +2,30 @@ import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v vertex(a...) = IVertex{Any}(a...) -∇graph(f, ∇, a) = (@v(∇ .* ∇₁(f)(a)),) +# Special case a couple of operators to clean up output code +const symbolic = Dict() -∇graph(::typeof(+), ∇, a...) = (∇ for _ in a) +symbolic[:+] = (Δ, args...) -> map(_->Δ, args) -∇graph(::typeof(-), ∇, a, b) = ∇, @v(-∇) +function ∇v(v::Vertex, Δ) + haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...) + Δ = vertex(:back!, vertex(value(v)), Δ, inputs(v)...) + map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v)) +end -∇graph(::typeof(*), ∇, a, b) = map(x->@v(∇ * transpose(x)), (b, a)) - -function ∇graph(v::IVertex, ∇, out = d()) +function invert(v::IVertex, Δ, out = d()) if isconstant(v) @assert !haskey(out, value(v)) - out[value(v)] = il(∇) + out[value(v)] = il(Δ) else - ∇′s = ∇graph(value(v), ∇, inputs(v)...) - for (v′, ∇′) in zip(inputs(v), ∇′s) - ∇graph(v′, ∇′, out) + Δ′s = ∇v(v, Δ) + for (v′, Δ′) in zip(inputs(v), Δ′s) + invert(v′, Δ′, out) end end return out end -macro derive(ex) - ∇s = ∇graph(il(graphm(resolve_calls(ex))), @flow(∇)) - v = vertex(Flow.Do(), (@v(Flow.Assign(Symbol("∇", k))(v)) for (k, v) in ∇s)...) - Expr(:quote, @> v cse syntax prettify) -end +back!(::typeof(+), Δ, args...) = map(_ -> Δ, args) + +back!(::typeof(*), Δ, a, b) = Δ*b', Δ*a'