simplify things

This commit is contained in:
Mike J Innes 2016-05-16 22:04:29 +01:00
parent eaa77cc5a6
commit 2706beaf3f

View File

@ -1,35 +1,24 @@
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
vertex(a...) = IVertex{Any}(a...)
(f, a::IVertex) =
(v(∇₁(f), a),)
∇graph(f, , a) = (@v( .* ∇₁(f)(a)),)
(::typeof(+), a::IVertex, b::IVertex) =
v(1), v(1)
∇graph(::typeof(+), , a, b) = ,
(::typeof(-), a::IVertex, b::IVertex) =
v(1), v(-1)
∇graph(::typeof(-), , a, b) = , @v(-)
(::typeof(*), a::IVertex, b::IVertex) =
v(transpose, b), v(transpose, a)
∇graph(::typeof(*), , a, b) = map(x->@v( * transpose(x)), (b, a))
function v(v::IVertex, chain::Vector{IVertex}, out = d())
function graph(v::IVertex, , out = d())
if isconstant(v)
@assert !haskey(out, value(v))
out[value(v)] = length(chain) == 1 ?
first(chain) :
foldl((x, y) -> vertex(*, x, y), chain)
out[value(v)] =
else
s = (value(v), inputs(v)...)
for (v, ∇′) in zip(inputs(v), s)
v(v, (value(∇′) 1 ? push!(copy(chain), ∇′) : chain), out)
s = ∇graph(value(v), , inputs(v)...)
for (v, ∇′) in zip(inputs(v), s)
graph(v, ∇′, out)
end
end
return out
end
∇v(v::Vertex, chain::Vector) = ∇v(convert(IVertex, v), convert(Vector{IVertex}, chain))
∇v(v::Vertex, ::Vertex) = ∇v(v, [])