simplify things
This commit is contained in:
parent
eaa77cc5a6
commit
2706beaf3f
@ -1,35 +1,24 @@
|
|||||||
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax
|
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
|
||||||
|
|
||||||
vertex(a...) = IVertex{Any}(a...)
|
vertex(a...) = IVertex{Any}(a...)
|
||||||
|
|
||||||
∇(f, a::IVertex) =
|
∇graph(f, ∇, a) = (@v(∇ .* ∇₁(f)(a)),)
|
||||||
(v(∇₁(f), a),)
|
|
||||||
|
|
||||||
∇(::typeof(+), a::IVertex, b::IVertex) =
|
∇graph(::typeof(+), ∇, a, b) = ∇, ∇
|
||||||
v(1), v(1)
|
|
||||||
|
|
||||||
∇(::typeof(-), a::IVertex, b::IVertex) =
|
∇graph(::typeof(-), ∇, a, b) = ∇, @v(-∇)
|
||||||
v(1), v(-1)
|
|
||||||
|
|
||||||
∇(::typeof(*), a::IVertex, b::IVertex) =
|
∇graph(::typeof(*), ∇, a, b) = map(x->@v(∇ * transpose(x)), (b, a))
|
||||||
v(transpose, b), v(transpose, a)
|
|
||||||
|
|
||||||
function ∇v(v::IVertex, chain::Vector{IVertex}, out = d())
|
function ∇graph(v::IVertex, ∇, out = d())
|
||||||
if isconstant(v)
|
if isconstant(v)
|
||||||
@assert !haskey(out, value(v))
|
@assert !haskey(out, value(v))
|
||||||
out[value(v)] = length(chain) == 1 ?
|
out[value(v)] = ∇
|
||||||
first(chain) :
|
|
||||||
foldl((x, y) -> vertex(*, x, y), chain)
|
|
||||||
else
|
else
|
||||||
∇s = ∇(value(v), inputs(v)...)
|
∇′s = ∇graph(value(v), ∇, inputs(v)...)
|
||||||
for (v′, ∇′) in zip(inputs(v), ∇s)
|
for (v′, ∇′) in zip(inputs(v), ∇′s)
|
||||||
∇v(v′, (value(∇′) ≠ 1 ? push!(copy(chain), ∇′) : chain), out)
|
∇graph(v′, ∇′, out)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
return out
|
return out
|
||||||
end
|
end
|
||||||
|
|
||||||
∇v(v::Vertex, chain::Vector) = ∇v(convert(IVertex, v), convert(Vector{IVertex}, chain))
|
|
||||||
|
|
||||||
∇v(v::Vertex, ∂::Vertex) = ∇v(v, [∂])
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user