tweak for flow api

This commit is contained in:
Mike J Innes 2016-08-18 22:31:49 +01:00
parent bd8c935aef
commit 4667c55a8a
2 changed files with 5 additions and 12 deletions

View File

@ -1,8 +1,8 @@
import Flow: mapconst
import Flow: mapconst, cse
function process_func(ex, params)
@capture(shortdef(ex), (args__,) -> body_)
body = il(graphm(body))
body = Flow.il(graphm(body))
body = mapconst(x -> x in params ? :(self.$x) : x, body)
return args, body
end
@ -30,7 +30,7 @@ function build_backward(body, x, params)
k = symbol("Δ", param)
ksym = Expr(:quote, k)
ex = Δs[:(self.$param)]
thread!(back, @vertex(setfield!(:self, ksym, :(self.$k) + ex)))
thread!(back, @dvertex(setfield!(:self, ksym, :(self.$k) + ex)))
end
ex = Δs[x]
thread!(back, @flow(tuple($ex)))

View File

@ -1,10 +1,3 @@
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @vertex
vertex(a...) = IVertex{Any}(a...)
vertex(v::Vertex) = convert(IVertex{Any}, v)
constant(x) = vertex(Flow.Constant(x))
constant(x::Vertex) = vertex(x)
addΔ(a, b) = vertex(:+, a, b)
# Special case a couple of operators to clean up output code
@ -18,11 +11,11 @@ function ∇v(v::Vertex, Δ)
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
end
function invert(v::IVertex, Δ = constant(), out = d())
function invert(v::IVertex, Δ = , out = d())
@assert !iscyclic(v)
if isconstant(v)
@assert !haskey(out, value(v))
out[value(v).value] = il(Δ)
out[value(v).value] = constant(Δ)
else
Δs = ∇v(v, Δ)
for (v, Δ′) in zip(inputs(v), Δs)