unify functions and models in autodiff

This commit is contained in:
Mike J Innes 2016-06-06 12:36:59 +01:00
parent 12f8fea206
commit 3bb6f2cdd0
2 changed files with 20 additions and 28 deletions

View File

@ -1,22 +1,14 @@
function forward_temporaries(body, ∇s)
exs = union((common(body, ) for in values(∇s))...)
filter!(ex -> !@capture(value(ex), self._), exs)
filter!(ex -> !(@capture(value(ex), self._) || isconstant(ex)), exs)
[ex=>symbol("temp", i) for (i, ex) in enumerate(exs)]
end
resolve_calls(ex) = ex
function resolve_calls(ex::Expr)
@capture(ex, f_(a__)) ?
Expr(:call, eval(current_module(), f), map(resolve_calls, a)...) :
Expr(ex.head, map(resolve_calls, ex.args))
end
function process_func(ex, params)
@capture(shortdef(ex), (args__,) -> body_)
body = il(graphm(resolve_calls(body)))
body = il(graphm(body))
body = map(x -> x in params ? :(self.$x) : x, body)
= ∇graph(body, @flow())
= invert(body, @flow())
return args, body,
end
@ -74,11 +66,10 @@ function process_type(ex)
args, body, ∇s = process_func(funcs[1], params)
@assert length(args) == 1
temps = forward_temporaries(body, ∇s)
∇s
quote
$(build_type(T, params, collect(values(temps))))
(self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
back!(self::$T, ) = $(syntax(build_backward(∇s, args[1], params, temps)))
back!(self::$T, , $(args...)) = $(syntax(build_backward(∇s, args[1], params, temps)))
$(build_update(T, params))
end |> longdef |> MacroTools.flatten
end

View File

@ -2,29 +2,30 @@ import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
vertex(a...) = IVertex{Any}(a...)
∇graph(f, , a) = (@v( .* ∇₁(f)(a)),)
# Special case a couple of operators to clean up output code
const symbolic = Dict()
∇graph(::typeof(+), , a...) = ( for _ in a)
symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
∇graph(::typeof(-), , a, b) = , @v(-)
function ∇v(v::Vertex, Δ)
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
Δ = vertex(:back!, vertex(value(v)), Δ, inputs(v)...)
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
end
∇graph(::typeof(*), , a, b) = map(x->@v( * transpose(x)), (b, a))
function ∇graph(v::IVertex, , out = d())
function invert(v::IVertex, Δ, out = d())
if isconstant(v)
@assert !haskey(out, value(v))
out[value(v)] = il()
out[value(v)] = il(Δ)
else
s = ∇graph(value(v), , inputs(v)...)
for (v, ) in zip(inputs(v), s)
∇graph(v, , out)
Δs = ∇v(v, Δ)
for (v, Δ) in zip(inputs(v), Δs)
invert(v, Δ, out)
end
end
return out
end
macro derive(ex)
∇s = ∇graph(il(graphm(resolve_calls(ex))), @flow())
v = vertex(Flow.Do(), (@v(Flow.Assign(Symbol("", k))(v)) for (k, v) in ∇s)...)
Expr(:quote, @> v cse syntax prettify)
end
back!(::typeof(+), Δ, args...) = map(_ -> Δ, args)
back!(::typeof(*), Δ, a, b) = Δ*b', Δ*a'