diff --git a/src/rt/diff.jl b/src/rt/diff.jl index 90a162c0..db07c8b8 100644 --- a/src/rt/diff.jl +++ b/src/rt/diff.jl @@ -1,35 +1,24 @@ -import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax +import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v vertex(a...) = IVertex{Any}(a...) -∇(f, a::IVertex) = - (v(∇₁(f), a),) +∇graph(f, ∇, a) = (@v(∇ .* ∇₁(f)(a)),) -∇(::typeof(+), a::IVertex, b::IVertex) = - v(1), v(1) +∇graph(::typeof(+), ∇, a, b) = ∇, ∇ -∇(::typeof(-), a::IVertex, b::IVertex) = - v(1), v(-1) +∇graph(::typeof(-), ∇, a, b) = ∇, @v(-∇) -∇(::typeof(*), a::IVertex, b::IVertex) = - v(transpose, b), v(transpose, a) +∇graph(::typeof(*), ∇, a, b) = map(x->@v(∇ * transpose(x)), (b, a)) -function ∇v(v::IVertex, chain::Vector{IVertex}, out = d()) +function ∇graph(v::IVertex, ∇, out = d()) if isconstant(v) @assert !haskey(out, value(v)) - out[value(v)] = length(chain) == 1 ? - first(chain) : - foldl((x, y) -> vertex(*, x, y), chain) + out[value(v)] = ∇ else - ∇s = ∇(value(v), inputs(v)...) - for (v′, ∇′) in zip(inputs(v), ∇s) - ∇v(v′, (value(∇′) ≠ 1 ? push!(copy(chain), ∇′) : chain), out) + ∇′s = ∇graph(value(v), ∇, inputs(v)...) + for (v′, ∇′) in zip(inputs(v), ∇′s) + ∇graph(v′, ∇′, out) end end return out end - -∇v(v::Vertex, chain::Vector) = ∇v(convert(IVertex, v), convert(Vector{IVertex}, chain)) - -∇v(v::Vertex, ∂::Vertex) = ∇v(v, [∂]) -