# Flux.jl/src/compiler/diff.jl (31 lines, 806 B, Julia)

# 2016-04-01 21:11:42 +00:00
"""
    addΔ(a, b)

Return `vertex(:+, a, b)`, presumably a graph node applying `+` to the two
sensitivities `a` and `b`.
"""
function addΔ(a, b)
    summed = vertex(:+, a, b)
    return summed
end
# Special case a couple of operators to clean up output code
const symbolic = Dict()
symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
"""
    ∇v(v::Vertex, Δ)

Split the sensitivity `Δ` arriving at `v` into one sensitivity per input of `v`.

If `value(v)` has an entry in `symbolic`, that rule is applied directly to
`Δ` and `inputs(v)`. Otherwise a single `:back!` vertex is built from
`constant(value(v))`, `constant(Δ)` and `v`'s inputs. The result is then the
collection of `getindex(back, i)` expressions (built via `@flow`) for
`i in 1:Flow.nin(v)`.
"""
function ∇v(v::Vertex, Δ)
    haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
    # Use a fresh local rather than reassigning the parameter `Δ`: a captured
    # variable that is reassigned gets boxed by the closure below.
    back = vertex(:back!, constant(value(v)), constant(Δ), inputs(v)...)
    return map(i -> @flow(getindex($back, $i)), 1:Flow.nin(v))
end
# 2016-08-18 21:31:49 +00:00
"""
    invert(v::IVertex, Δ, out = d())

Propagate the sensitivity `Δ` backwards through the acyclic graph rooted at
`v`, recording the sensitivity that reaches each constant leaf.

- **Constant leaf `c`:** `out[value(c).value]` is set to `constant(Δc)`, where
  `Δc` is the sensitivity that arrived at `c`.
- **Any other vertex:** `∇v` splits the incoming sensitivity into one per
  input, and the walk recurses into each input with the same `out`.

`out` is mutated in place and returned. It is presumably a `Dict` built by
`d()` — confirm.
"""
function invert(v::IVertex, Δ, out = d())
    # NOTE(review): the default value of `Δ` was lost from this copy of the
    # file (`Δ = , out = d()` does not parse). `Δ` is required until the
    # intended default is recovered from version control.
    @assert !iscyclic(v)
    if isconstant(v)
        key = value(v).value
        # Check the key that is actually written. The previous check used
        # `value(v)`, which never matches a stored key, so a leaf reached
        # twice silently overwrote its first sensitivity.
        # NOTE(review): shared leaves would need accumulation (e.g. via addΔ).
        @assert !haskey(out, key)
        out[key] = constant(Δ)
    else
        for (child, Δ′) in zip(inputs(v), ∇v(v, Δ))
            invert(child, Δ′, out)
        end
    end
    return out
end
"""
    back!(::typeof(+), Δ, args...)

Gradient rule for `+`: every summand receives the incoming sensitivity `Δ`
unchanged. Returns a tuple with one entry per element of `args`.
"""
function back!(::typeof(+), Δ, args...)
    return ntuple(_ -> Δ, length(args))
end
"""
    back!(::typeof(*), Δ, a, b)

Gradient rule for the product `a * b`. Given the sensitivity `Δ` of the
result, return the tuple `(Δ * b', a' * Δ)`: the sensitivities with respect
to `a` and `b`. For matrices, these have the shapes of `a` and `b`.
"""
back!(::typeof(*), Δ, a, b) = (Δ * b', a' * Δ)