diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 27a35c63..e44e6563 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -1,8 +1,8 @@ -import Flow: mapconst +import Flow: mapconst, cse function process_func(ex, params) @capture(shortdef(ex), (args__,) -> body_) - body = il(graphm(body)) + body = Flow.il(graphm(body)) body = mapconst(x -> x in params ? :(self.$x) : x, body) return args, body end @@ -30,7 +30,7 @@ function build_backward(body, x, params) k = symbol("Δ", param) ksym = Expr(:quote, k) ex = Δs[:(self.$param)] - thread!(back, @vertex(setfield!(:self, ksym, :(self.$k) + ex))) + thread!(back, @dvertex(setfield!(:self, ksym, :(self.$k) + ex))) end ex = Δs[x] thread!(back, @flow(tuple($ex))) diff --git a/src/compiler/diff.jl b/src/compiler/diff.jl index efab7943..0cf4bf9b 100644 --- a/src/compiler/diff.jl +++ b/src/compiler/diff.jl @@ -1,10 +1,3 @@ -import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @vertex - -vertex(a...) = IVertex{Any}(a...) -vertex(v::Vertex) = convert(IVertex{Any}, v) -constant(x) = vertex(Flow.Constant(x)) -constant(x::Vertex) = vertex(x) - addΔ(a, b) = vertex(:+, a, b) # Special case a couple of operators to clean up output code @@ -18,11 +11,11 @@ function ∇v(v::Vertex, Δ) map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v)) end -function invert(v::IVertex, Δ = constant(:Δ), out = d()) +function invert(v::IVertex, Δ = :Δ, out = d()) @assert !iscyclic(v) if isconstant(v) @assert !haskey(out, value(v)) - out[value(v).value] = il(Δ) + out[value(v).value] = constant(Δ) else Δ′s = ∇v(v, Δ) for (v′, Δ′) in zip(inputs(v), Δ′s)