From c6eb901f1d5d59fa4b9f37999cbb2d2073f20bb1 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 18 Aug 2016 22:06:12 +0100 Subject: [PATCH] updates for Flow --- src/compiler/code.jl | 6 ++++-- src/compiler/diff.jl | 11 +++++++---- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/compiler/code.jl b/src/compiler/code.jl index 335c7b71..27a35c63 100644 --- a/src/compiler/code.jl +++ b/src/compiler/code.jl @@ -1,7 +1,9 @@ +import Flow: mapconst + function process_func(ex, params) @capture(shortdef(ex), (args__,) -> body_) body = il(graphm(body)) - body = map(x -> x in params ? :(self.$x) : x, body) + body = mapconst(x -> x in params ? :(self.$x) : x, body) return args, body end @@ -28,7 +30,7 @@ function build_backward(body, x, params) k = symbol("Δ", param) ksym = Expr(:quote, k) ex = Δs[:(self.$param)] - thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex))) + thread!(back, @vertex(setfield!(:self, ksym, :(self.$k) + ex))) end ex = Δs[x] thread!(back, @flow(tuple($ex))) diff --git a/src/compiler/diff.jl b/src/compiler/diff.jl index dfc4b960..efab7943 100644 --- a/src/compiler/diff.jl +++ b/src/compiler/diff.jl @@ -1,6 +1,9 @@ -import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v +import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @vertex vertex(a...) = IVertex{Any}(a...) +vertex(v::Vertex) = convert(IVertex{Any}, v) +constant(x) = vertex(Flow.Constant(x)) +constant(x::Vertex) = vertex(x) addΔ(a, b) = vertex(:+, a, b) @@ -11,15 +14,15 @@ symbolic[:+] = (Δ, args...) -> map(_->Δ, args) function ∇v(v::Vertex, Δ) haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...) - Δ = vertex(:back!, vertex(value(v)), Δ, inputs(v)...) + Δ = vertex(:back!, constant(value(v)), constant(Δ), inputs(v)...) map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v)) end -function invert(v::IVertex, Δ = vertex(:Δ), out = d()) +function invert(v::IVertex, Δ = constant(:Δ), out = d()) @assert !iscyclic(v) if isconstant(v) @assert !haskey(out, value(v)) - out[value(v)] = il(Δ) + out[value(v).value] = il(Δ) else Δ′s = ∇v(v, Δ) for (v′, Δ′) in zip(inputs(v), Δ′s)