compile the forward pass again
This commit is contained in:
parent
9921881d47
commit
bad6b2d1ae
|
@ -1,4 +1,5 @@
|
|||
import DataFlow: mapconst, cse
|
||||
import MacroTools: @q
|
||||
|
||||
export @net, @ml
|
||||
|
||||
|
@ -39,6 +40,22 @@ function build_type(T, params)
|
|||
ex
|
||||
end
|
||||
|
||||
runmodel(f, xs...) = f(xs...)
|
||||
|
||||
function deref_params(v)
|
||||
v = map(v) do x
|
||||
isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
|
||||
end
|
||||
prewalk(v) do v
|
||||
@capture(value(v), self.p_) ? vertex(:(Flux.runmodel), constant(:(self.$p)), inputs(v)...) : v
|
||||
end
|
||||
end
|
||||
|
||||
function build_forward(body, args)
|
||||
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
|
||||
applylines(syntax(cse(deref_params(body))))
|
||||
end
|
||||
|
||||
import Lazy: groupby
|
||||
|
||||
reifyparams(v::IVertex) = mapconst(x -> isa(x, Param) ? x.x : x, v)
|
||||
|
@ -52,9 +69,11 @@ function process_type(ex)
|
|||
args, body = process_func(funcs[1], pnames)
|
||||
@assert length(args) == 1
|
||||
self = esc(:self)
|
||||
quote
|
||||
@q begin
|
||||
$(build_type(T, params))
|
||||
$(esc(:(self::$T)))($(args...),) = interpmodel($self, $(args...))
|
||||
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
|
||||
($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...)))
|
||||
($self::$(esc(T)))($(args...)) = first($self(map(batchone, ($(args...),))...))
|
||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||
nothing
|
||||
|
|
|
@ -22,12 +22,13 @@ end
|
|||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||
tlp = TLP(a1, a2)
|
||||
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
||||
@test Flux.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
||||
@test Flux.infer(tlp, (1, 10)) == (1,15)
|
||||
end
|
||||
|
||||
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
||||
e = try
|
||||
tlp(rand(10))
|
||||
Flux.interpmodel(tlp, rand(10))
|
||||
catch e
|
||||
e
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue