use interpreter for forward pass

This commit is contained in:
Mike J Innes 2016-11-13 20:46:35 +00:00
parent c654fe403a
commit 5a32c72362
2 changed files with 4 additions and 30 deletions

View File

@ -4,7 +4,7 @@ using MacroTools, Lazy, DataFlow, Juno
using DataFlow: graphm, syntax, prewalk!, postwalk!, prewalk, postwalk,
iscyclic, Constant, constant, isconstant, Group, group, Split, splitnode,
detuple, value, inputs, thread!, value, inputs, Split, splitnode, inputnode,
spliceinputs, bumpinputs
spliceinputs, bumpinputs, interpret
using Juno: Tree, Row
# Zero Flux Given

View File

@ -34,35 +34,10 @@ function build_type(T, params)
ex
end
function deref_params(v)
map(v) do x
isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(state(self.$p))) : x
end
end
function build_forward(body, args)
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
syntax(cse(deref_params(body)))
end
function build_backward(body, x, params = [])
iscyclic(body) && return :(error("Can't run backward pass on a cyclic graph"))
Δs = invert(body)
back = IVertex{Any}(DataFlow.Do())
for param in params
haskey(Δs, :(self.$param)) || continue
ex = Δs[:(self.$param)]
ex = deref_params(ex)
thread!(back, @vtx(accumulate!(:(self.$param), ex)))
end
ex = Δs[x]
ex = deref_params(ex)
thread!(back, @flow(tuple($ex)))
syntax(cse(back))
end
import Lazy: groupby
reifyparams(v::IVertex) = mapconst(x -> isa(x, Param) ? x.x : x, v)
function process_type(ex)
@capture(ex, type T_ fs__ end)
@destruct [params = false || [],
@ -73,8 +48,7 @@ function process_type(ex)
@assert length(args) == 1
quote
$(build_type(T, params))
(self::$T)($(args...),) = $(build_forward(body, args))
back!(self::$T, Δ, $(args...)) = $(build_backward(body, args[1], pnames))
(self::$T)($(args...),) = interpret(reifyparams(graph(self)), $(args...))
update!(self::$T, η) = $(map(p -> :(update!(self.$p, η)), pnames)...)
graph(self::$T) = $(DataFlow.constructor(makegraph(body, args)))
nothing