use interpreter for forward pass
This commit is contained in:
parent
c654fe403a
commit
5a32c72362
@ -4,7 +4,7 @@ using MacroTools, Lazy, DataFlow, Juno
|
|||||||
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
|
||||||
iscyclic, Constant, constant, isconstant, Group, group, Split, splitnode,
|
iscyclic, Constant, constant, isconstant, Group, group, Split, splitnode,
|
||||||
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
|
||||||
spliceinputs, bumpinputs
|
spliceinputs, bumpinputs, interpret
|
||||||
using Juno: Tree, Row
|
using Juno: Tree, Row
|
||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
@ -34,35 +34,10 @@ function build_type(T, params)
|
|||||||
ex
|
ex
|
||||||
end
|
end
|
||||||
|
|
||||||
function deref_params(v)
|
|
||||||
map(v) do x
|
|
||||||
isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(state(self.$p))) : x
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function build_forward(body, args)
|
|
||||||
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
|
|
||||||
syntax(cse(deref_params(body)))
|
|
||||||
end
|
|
||||||
|
|
||||||
function build_backward(body, x, params = [])
|
|
||||||
iscyclic(body) && return :(error("Can't run backward pass on a cyclic graph"))
|
|
||||||
Δs = invert(body)
|
|
||||||
back = IVertex{Any}(DataFlow.Do())
|
|
||||||
for param in params
|
|
||||||
haskey(Δs, :(self.$param)) || continue
|
|
||||||
ex = Δs[:(self.$param)]
|
|
||||||
ex = deref_params(ex)
|
|
||||||
thread!(back, @vtx(accumulate!(:(self.$param), ex)))
|
|
||||||
end
|
|
||||||
ex = Δs[x]
|
|
||||||
ex = deref_params(ex)
|
|
||||||
thread!(back, @flow(tuple($ex)))
|
|
||||||
syntax(cse(back))
|
|
||||||
end
|
|
||||||
|
|
||||||
import Lazy: groupby
|
import Lazy: groupby
|
||||||
|
|
||||||
|
reifyparams(v::IVertex) = mapconst(x -> isa(x, Param) ? x.x : x, v)
|
||||||
|
|
||||||
function process_type(ex)
|
function process_type(ex)
|
||||||
@capture(ex, type T_ fs__ end)
|
@capture(ex, type T_ fs__ end)
|
||||||
@destruct [params = false || [],
|
@destruct [params = false || [],
|
||||||
@ -73,8 +48,7 @@ function process_type(ex)
|
|||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
quote
|
quote
|
||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
(self::$T)($(args...),) = $(build_forward(body, args))
|
(self::$T)($(args...),) = interpret(reifyparams(graph(self)), $(args...))
|
||||||
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
|
|
||||||
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
|
||||||
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
|
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
|
||||||
nothing
|
nothing
|
||||||
|
Loading…
Reference in New Issue
Block a user