updates for Flow
This commit is contained in:
parent
9986a1c163
commit
bd8c935aef
@ -1,7 +1,9 @@
|
|||||||
|
import Flow: mapconst
|
||||||
|
|
||||||
function process_func(ex, params)
|
function process_func(ex, params)
|
||||||
@capture(shortdef(ex), (args__,) -> body_)
|
@capture(shortdef(ex), (args__,) -> body_)
|
||||||
body = il(graphm(body))
|
body = il(graphm(body))
|
||||||
body = map(x -> x in params ? :(self.$x) : x, body)
|
body = mapconst(x -> x in params ? :(self.$x) : x, body)
|
||||||
return args, body
|
return args, body
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -28,7 +30,7 @@ function build_backward(body, x, params)
|
|||||||
k = symbol("Δ", param)
|
k = symbol("Δ", param)
|
||||||
ksym = Expr(:quote, k)
|
ksym = Expr(:quote, k)
|
||||||
ex = Δs[:(self.$param)]
|
ex = Δs[:(self.$param)]
|
||||||
thread!(back, @v(setfield!(:self, ksym, :(self.$k) + ex)))
|
thread!(back, @vertex(setfield!(:self, ksym, :(self.$k) + ex)))
|
||||||
end
|
end
|
||||||
ex = Δs[x]
|
ex = Δs[x]
|
||||||
thread!(back, @flow(tuple($ex)))
|
thread!(back, @flow(tuple($ex)))
|
||||||
|
@ -1,6 +1,9 @@
|
|||||||
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @v
|
import Flow: isconstant, il, dl, cse, prewalk, graphm, syntax, @vertex
|
||||||
|
|
||||||
vertex(a...) = IVertex{Any}(a...)
|
vertex(a...) = IVertex{Any}(a...)
|
||||||
|
vertex(v::Vertex) = convert(IVertex{Any}, v)
|
||||||
|
constant(x) = vertex(Flow.Constant(x))
|
||||||
|
constant(x::Vertex) = vertex(x)
|
||||||
|
|
||||||
addΔ(a, b) = vertex(:+, a, b)
|
addΔ(a, b) = vertex(:+, a, b)
|
||||||
|
|
||||||
@ -11,15 +14,15 @@ symbolic[:+] = (Δ, args...) -> map(_->Δ, args)
|
|||||||
|
|
||||||
function ∇v(v::Vertex, Δ)
|
function ∇v(v::Vertex, Δ)
|
||||||
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
|
haskey(symbolic, value(v)) && return symbolic[value(v)](Δ, inputs(v)...)
|
||||||
Δ = vertex(:back!, vertex(value(v)), Δ, inputs(v)...)
|
Δ = vertex(:back!, constant(value(v)), constant(Δ), inputs(v)...)
|
||||||
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
|
map(i -> @flow(getindex($Δ, $i)), 1:Flow.nin(v))
|
||||||
end
|
end
|
||||||
|
|
||||||
function invert(v::IVertex, Δ = vertex(:Δ), out = d())
|
function invert(v::IVertex, Δ = constant(:Δ), out = d())
|
||||||
@assert !iscyclic(v)
|
@assert !iscyclic(v)
|
||||||
if isconstant(v)
|
if isconstant(v)
|
||||||
@assert !haskey(out, value(v))
|
@assert !haskey(out, value(v))
|
||||||
out[value(v)] = il(Δ)
|
out[value(v).value] = il(Δ)
|
||||||
else
|
else
|
||||||
Δ′s = ∇v(v, Δ)
|
Δ′s = ∇v(v, Δ)
|
||||||
for (v′, Δ′) in zip(inputs(v), Δ′s)
|
for (v′, Δ′) in zip(inputs(v), Δ′s)
|
||||||
|
Loading…
Reference in New Issue
Block a user