simplify things

This commit is contained in:
Mike J Innes 2016-05-16 22:04:29 +01:00
parent eaa77cc5a6
commit 2706beaf3f

View File

@ -1,35 +1,24 @@
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
vertex(a...) = IVertex{Any}(a...) vertex(a...) = IVertex{Any}(a...)
(f, a::IVertex) = ∇graph(f, , a) = (@v( .* ∇₁(f)(a)),)
(v(∇₁(f), a),)
(::typeof(+), a::IVertex, b::IVertex) = ∇graph(::typeof(+), , a, b) = ,
v(1), v(1)
(::typeof(-), a::IVertex, b::IVertex) = ∇graph(::typeof(-), , a, b) = , @v(-)
v(1), v(-1)
(::typeof(*), a::IVertex, b::IVertex) = ∇graph(::typeof(*), , a, b) = map(x->@v( * transpose(x)), (b, a))
v(transpose, b), v(transpose, a)
function v(v::IVertex, chain::Vector{IVertex}, out = d()) function graph(v::IVertex, , out = d())
if isconstant(v) if isconstant(v)
@assert !haskey(out, value(v)) @assert !haskey(out, value(v))
out[value(v)] = length(chain) == 1 ? out[value(v)] =
first(chain) :
foldl((x, y) -> vertex(*, x, y), chain)
else else
s = (value(v), inputs(v)...) s = ∇graph(value(v), , inputs(v)...)
for (v, ∇′) in zip(inputs(v), s) for (v, ∇′) in zip(inputs(v), s)
v(v, (value(∇′) 1 ? push!(copy(chain), ∇′) : chain), out) graph(v, ∇′, out)
end end
end end
return out return out
end end
∇v(v::Vertex, chain::Vector) = ∇v(convert(IVertex, v), convert(Vector{IVertex}, chain))
∇v(v::Vertex, ::Vertex) = ∇v(v, [])