unify functions and models in autodiff

This commit is contained in:
Mike J Innes 2016-06-06 12:36:59 +01:00
parent 12f8fea206
commit 3bb6f2cdd0
2 changed files with 20 additions and 28 deletions

View File

@ -1,22 +1,14 @@
function forward_temporaries(body, ∇s) function forward_temporaries(body, ∇s)
exs = union((common(body, ) for in values(∇s))...) exs = union((common(body, ) for in values(∇s))...)
filter!(ex -> !@capture(value(ex), self._), exs) filter!(ex -> !(@capture(value(ex), self._) || isconstant(ex)), exs)
[ex=>symbol("temp", i) for (i, ex) in enumerate(exs)] [ex=>symbol("temp", i) for (i, ex) in enumerate(exs)]
end end
resolve_calls(ex) = ex
function resolve_calls(ex::Expr)
@capture(ex, f_(a__)) ?
Expr(:call, eval(current_module(), f), map(resolve_calls, a)...) :
Expr(ex.head, map(resolve_calls, ex.args))
end
function process_func(ex, params) function process_func(ex, params)
@capture(shortdef(ex), (args__,) -> body_) @capture(shortdef(ex), (args__,) -> body_)
body = il(graphm(resolve_calls(body))) body = il(graphm(body))
body = map(x -> x in params ? :(self.$x) : x, body) body = map(x -> x in params ? :(self.$x) : x, body)
= ∇graph(body, @flow()) = invert(body, @flow())
return args, body, return args, body,
end end
@ -74,11 +66,10 @@ function process_type(ex)
args, body, ∇s = process_func(funcs[1], params) args, body, ∇s = process_func(funcs[1], params)
@assert length(args) == 1 @assert length(args) == 1
temps = forward_temporaries(body, ∇s) temps = forward_temporaries(body, ∇s)
∇s
quote quote
$(build_type(T, params, collect(values(temps)))) $(build_type(T, params, collect(values(temps))))
(self::$T)($(args...),) = $(syntax(build_forward(body, temps))) (self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
back!(self::$T, ) = $(syntax(build_backward(∇s, args[1], params, temps))) back!(self::$T, , $(args...)) = $(syntax(build_backward(∇s, args[1], params, temps)))
$(build_update(T, params)) $(build_update(T, params))
end |> longdef |> MacroTools.flatten end |> longdef |> MacroTools.flatten
end end

View File

@ -2,29 +2,30 @@ import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
vertex(a...) = IVertex{Any}(a...) vertex(a...) = IVertex{Any}(a...)
∇graph(f, , a) = (@v( .* ∇₁(f)(a)),) # Special case a couple of operators to clean up output code
const symbolic = Dict()
∇graph(::typeof(+), , a...) = ( for _ in a) symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
∇graph(::typeof(-), , a, b) = , @v(-) function ∇v(v::Vertex, Δ)
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
Δ = vertex(:back!, vertex(value(v)), Δ, inputs(v)...)
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
end
∇graph(::typeof(*), , a, b) = map(x->@v( * transpose(x)), (b, a)) function invert(v::IVertex, Δ, out = d())
function ∇graph(v::IVertex, , out = d())
if isconstant(v) if isconstant(v)
@assert !haskey(out, value(v)) @assert !haskey(out, value(v))
out[value(v)] = il() out[value(v)] = il(Δ)
else else
s = ∇graph(value(v), , inputs(v)...) Δs = ∇v(v, Δ)
for (v, ) in zip(inputs(v), s) for (v, Δ) in zip(inputs(v), Δs)
∇graph(v, , out) invert(v, Δ, out)
end end
end end
return out return out
end end
macro derive(ex) back!(::typeof(+), Δ, args...) = map(_ -> Δ, args)
∇s = ∇graph(il(graphm(resolve_calls(ex))), @flow())
v = vertex(Flow.Do(), (@v(Flow.Assign(Symbol("", k))(v)) for (k, v) in ∇s)...) back!(::typeof(*), Δ, a, b) = Δ*b', Δ*a'
Expr(:quote, @> v cse syntax prettify)
end