unify functions and models in autodiff
This commit is contained in:
parent
12f8fea206
commit
3bb6f2cdd0
@ -1,22 +1,14 @@
|
||||
function forward_temporaries(body, ∇s)
|
||||
exs = union((common(body, ∇) for ∇ in values(∇s))...)
|
||||
filter!(ex -> !@capture(value(ex), self._), exs)
|
||||
filter!(ex -> !(@capture(value(ex), self._) || isconstant(ex)), exs)
|
||||
[ex=>symbol("temp", i) for (i, ex) in enumerate(exs)]
|
||||
end
|
||||
|
||||
resolve_calls(ex) = ex
|
||||
|
||||
function resolve_calls(ex::Expr)
|
||||
@capture(ex, f_(a__)) ?
|
||||
Expr(:call, eval(current_module(), f), map(resolve_calls, a)...) :
|
||||
Expr(ex.head, map(resolve_calls, ex.args))
|
||||
end
|
||||
|
||||
function process_func(ex, params)
|
||||
@capture(shortdef(ex), (args__,) -> body_)
|
||||
body = il(graphm(resolve_calls(body)))
|
||||
body = il(graphm(body))
|
||||
body = map(x -> x in params ? :(self.$x) : x, body)
|
||||
∇ = ∇graph(body, @flow(∇))
|
||||
∇ = invert(body, @flow(∇))
|
||||
return args, body, ∇
|
||||
end
|
||||
|
||||
@ -74,11 +66,10 @@ function process_type(ex)
|
||||
args, body, ∇s = process_func(funcs[1], params)
|
||||
@assert length(args) == 1
|
||||
temps = forward_temporaries(body, ∇s)
|
||||
∇s
|
||||
quote
|
||||
$(build_type(T, params, collect(values(temps))))
|
||||
(self::$T)($(args...),) = $(syntax(build_forward(body, temps)))
|
||||
back!(self::$T, ∇) = $(syntax(build_backward(∇s, args[1], params, temps)))
|
||||
back!(self::$T, ∇, $(args...)) = $(syntax(build_backward(∇s, args[1], params, temps)))
|
||||
$(build_update(T, params))
|
||||
end |> longdef |> MacroTools.flatten
|
||||
end
|
||||
|
@ -2,29 +2,30 @@ import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
|
||||
|
||||
vertex(a...) = IVertex{Any}(a...)
|
||||
|
||||
∇graph(f, ∇, a) = (@v(∇ .* ∇₁(f)(a)),)
|
||||
# Special case a couple of operators to clean up output code
|
||||
const symbolic = Dict()
|
||||
|
||||
∇graph(::typeof(+), ∇, a...) = (∇ for _ in a)
|
||||
symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
|
||||
|
||||
∇graph(::typeof(-), ∇, a, b) = ∇, @v(-∇)
|
||||
function ∇v(v::Vertex, Δ)
|
||||
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
|
||||
Δ = vertex(:back!, vertex(value(v)), Δ, inputs(v)...)
|
||||
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
|
||||
end
|
||||
|
||||
∇graph(::typeof(*), ∇, a, b) = map(x->@v(∇ * transpose(x)), (b, a))
|
||||
|
||||
function ∇graph(v::IVertex, ∇, out = d())
|
||||
function invert(v::IVertex, Δ, out = d())
|
||||
if isconstant(v)
|
||||
@assert !haskey(out, value(v))
|
||||
out[value(v)] = il(∇)
|
||||
out[value(v)] = il(Δ)
|
||||
else
|
||||
∇′s = ∇graph(value(v), ∇, inputs(v)...)
|
||||
for (v′, ∇′) in zip(inputs(v), ∇′s)
|
||||
∇graph(v′, ∇′, out)
|
||||
Δ′s = ∇v(v, Δ)
|
||||
for (v′, Δ′) in zip(inputs(v), Δ′s)
|
||||
invert(v′, Δ′, out)
|
||||
end
|
||||
end
|
||||
return out
|
||||
end
|
||||
|
||||
macro derive(ex)
|
||||
∇s = ∇graph(il(graphm(resolve_calls(ex))), @flow(∇))
|
||||
v = vertex(Flow.Do(), (@v(Flow.Assign(Symbol("∇", k))(v)) for (k, v) in ∇s)...)
|
||||
Expr(:quote, @> v cse syntax prettify)
|
||||
end
|
||||
back!(::typeof(+), Δ, args...) = map(_ -> Δ, args)
|
||||
|
||||
back!(::typeof(*), Δ, a, b) = Δ*b', Δ*a'
|
||||
|
Loading…
Reference in New Issue
Block a user