compile the forward pass again

This commit is contained in:
Mike J Innes 2017-02-24 14:38:17 +00:00
parent 9921881d47
commit bad6b2d1ae
2 changed files with 23 additions and 3 deletions

View File

@ -1,4 +1,5 @@
import DataFlow: mapconst, cse
import MacroTools: @q
export @net, @ml
@ -39,6 +40,22 @@ function build_type(T, params)
ex
end
runmodel(f, xs...) = f(xs...)
function deref_params(v)
v = map(v) do x
isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
end
prewalk(v) do v
@capture(value(v), self.p_) ? vertex(:(Flux.runmodel), constant(:(self.$p)), inputs(v)...) : v
end
end
function build_forward(body, args)
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
applylines(syntax(cse(deref_params(body))))
end
import Lazy: groupby
reifyparams(v::IVertex) = mapconst(x -> isa(x, Param) ? x.x : x, v)
@ -52,9 +69,11 @@ function process_type(ex)
args, body = process_func(funcs[1], pnames)
@assert length(args) == 1
self = esc(:self)
quote
@q begin
$(build_type(T, params))
$(esc(:(self::$T)))($(args...),) = interpmodel($self, $(args...))
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...)))
($self::$(esc(T)))($(args...)) = first($self(map(batchone, ($(args...),))...))
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
nothing

View File

@ -22,12 +22,13 @@ end
let a1 = Affine(10, 20), a2 = Affine(20, 15)
tlp = TLP(a1, a2)
@test tlp(xs) softmax(a2(σ(a1(xs))))
@test Flux.interpmodel(tlp, xs) softmax(a2(σ(a1(xs))))
@test Flux.infer(tlp, (1, 10)) == (1,15)
end
let tlp = TLP(Affine(10, 21), Affine(20, 15))
e = try
tlp(rand(10))
Flux.interpmodel(tlp, rand(10))
catch e
e
end