2016-08-18 21:06:12 +00:00
|
|
|
|
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @vertex
|
2016-04-01 21:11:42 +00:00
|
|
|
|
|
|
|
|
|
"""
    vertex(args...)

Construct an `IVertex{Any}`, forwarding every argument unchanged to its constructor.
"""
function vertex(args...)
  return IVertex{Any}(args...)
end
|
2016-08-18 21:06:12 +00:00
|
|
|
|
"""
    vertex(v::Vertex)

Convert an existing vertex `v` to an `IVertex{Any}` instead of wrapping it as a new vertex.
"""
function vertex(v::Vertex)
  return convert(IVertex{Any}, v)
end
|
|
|
|
|
"""
    constant(x)

Wrap the plain value `x` in a `Flow.Constant` and turn it into a vertex.
"""
function constant(x)
  wrapped = Flow.Constant(x)
  return vertex(wrapped)
end
|
|
|
|
|
"""
    constant(x::Vertex)

Vertices are not wrapped in `Flow.Constant`; they are passed straight to `vertex`.
"""
function constant(x::Vertex)
  return vertex(x)
end
|
2016-04-01 21:11:42 +00:00
|
|
|
|
|
|
|
|
|
"""
    addΔ(a, b)

Build a vertex that applies `:+` to the two inputs `a` and `b`.
"""
function addΔ(a, b)
  return vertex(:+, a, b)
end
|
|
|
|
|
|
|
|
|
|
# Operators registered in `symbolic` get their gradients built directly as
# graph vertices (see `∇v`) instead of through a generic `back!` call, which
# keeps the generated code cleaner. Each entry maps an operator to a function
# `(Δ, args...) -> gradients` that returns one gradient per argument.
const symbolic = Dict{Any,Any}()

# For `+`, every argument's gradient is just the incoming sensitivity `Δ`.
symbolic[:+] = (Δ, args...) -> ntuple(_ -> Δ, length(args))
|
|
|
|
|
|
|
|
|
|
"""
    ∇v(v::Vertex, Δ)

Given the sensitivity `Δ` of `v`'s output, return gradient graphs for `v`'s inputs.

If `v`'s operator has an entry in `symbolic`, that entry builds the gradients
directly from `Δ` and the inputs. Otherwise a single `back!(op, Δ, inputs...)`
vertex is built, and its result is indexed at `1:Flow.nin(v)` (presumably one
entry per input).
"""
function ∇v(v::Vertex, Δ)
  op = value(v)
  if haskey(symbolic, op)
    return symbolic[op](Δ, inputs(v)...)
  end
  # One shared `back!` call; each input's gradient is an index into its result.
  pullback = vertex(:back!, constant(op), constant(Δ), inputs(v)...)
  return map(i -> @flow(getindex($pullback, $i)), 1:Flow.nin(v))
end
|
|
|
|
|
|
2016-08-18 21:06:12 +00:00
|
|
|
|
"""
    invert(v::IVertex, Δ = constant(:Δ), out = d())

Walk the acyclic graph rooted at `v` backwards, starting from the output
sensitivity `Δ`, and record a gradient graph for every constant leaf in `out`.

`out` is keyed by each leaf's wrapped value (`value(leaf).value`) and holds
an inlined (`il`) gradient graph. A leaf reached along more than one path
receives one gradient per path. These are summed with `addΔ` rather than the
later one overwriting the earlier one.

Returns `out`.
"""
function invert(v::IVertex, Δ = constant(:Δ), out = d())
  # Every subgraph of an acyclic graph is acyclic, so checking once here is
  # enough; the recursive worker does not repeat the check.
  @assert !iscyclic(v)
  return _invert!(v, Δ, out)
end

# Recursive worker for `invert`; assumes the graph was already checked for cycles.
function _invert!(v::IVertex, Δ, out)
  if isconstant(v)
    # Key by the wrapped value, and use the same key for the lookup and the store.
    key = value(v).value
    out[key] = haskey(out, key) ? il(addΔ(out[key], Δ)) : il(Δ)
  else
    for (v′, Δ′) in zip(inputs(v), ∇v(v, Δ))
      _invert!(v′, Δ′, out)
    end
  end
  return out
end
|
|
|
|
|
|
|
|
|
|
"""
    back!(::typeof(+), Δ, args...)

Reverse rule for `+`: every summand receives the incoming sensitivity `Δ`
unchanged. Returns a tuple with one entry per argument.
"""
function back!(::typeof(+), Δ, args...)
  return ntuple(_ -> Δ, length(args))
end
|
|
|
|
|
|
|
|
|
|
"""
    back!(::typeof(*), Δ, a, b)

Reverse rule for the product `a*b`. Given the sensitivity `Δ` of the result,
return `(Δ*b', a'*Δ)`: the gradients with respect to `a` and `b`.

For `a` (m×n), `b` (n×p) and `Δ` (m×p), these are m×n and n×p, matching the
shapes of `a` and `b`. The gradient for `b` must be `a'*Δ`; `Δ*a'` only
type-checks when p == n, and even then it is wrong. For real scalars both
orders agree.
"""
back!(::typeof(*), Δ, a, b) = (Δ*b', a'*Δ)
|