compile the forward pass again
This commit is contained in:
parent
9921881d47
commit
bad6b2d1ae
@ -1,4 +1,5 @@
|
|||||||
import DataFlow: mapconst, cse
|
import DataFlow: mapconst, cse
|
||||||
|
import MacroTools: @q
|
||||||
|
|
||||||
export @net, @ml
|
export @net, @ml
|
||||||
|
|
||||||
@ -39,6 +40,22 @@ function build_type(T, params)
|
|||||||
ex
|
ex
|
||||||
end
|
end
|
||||||
|
|
||||||
|
runmodel(f, xs...) = f(xs...)
|
||||||
|
|
||||||
|
function deref_params(v)
|
||||||
|
v = map(v) do x
|
||||||
|
isa(x, Constant) && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
|
||||||
|
end
|
||||||
|
prewalk(v) do v
|
||||||
|
@capture(value(v), self.p_) ? vertex(:(Flux.runmodel), constant(:(self.$p)), inputs(v)...) : v
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function build_forward(body, args)
|
||||||
|
iscyclic(body) && return :(error("Can't run forward pass on a cyclic graph"))
|
||||||
|
applylines(syntax(cse(deref_params(body))))
|
||||||
|
end
|
||||||
|
|
||||||
import Lazy: groupby
|
import Lazy: groupby
|
||||||
|
|
||||||
reifyparams(v::IVertex) = mapconst(x -> isa(x, Param) ? x.x : x, v)
|
reifyparams(v::IVertex) = mapconst(x -> isa(x, Param) ? x.x : x, v)
|
||||||
@ -52,9 +69,11 @@ function process_type(ex)
|
|||||||
args, body = process_func(funcs[1], pnames)
|
args, body = process_func(funcs[1], pnames)
|
||||||
@assert length(args) == 1
|
@assert length(args) == 1
|
||||||
self = esc(:self)
|
self = esc(:self)
|
||||||
quote
|
@q begin
|
||||||
$(build_type(T, params))
|
$(build_type(T, params))
|
||||||
$(esc(:(self::$T)))($(args...),) = interpmodel($self, $(args...))
|
$(esc(:(Flux.runmodel(self::$T, $(args...)) = $(build_forward(body, args)))))
|
||||||
|
($self::$(esc(T)))($(map(arg -> :($arg::Batch), args)...)) = rebatch(runmodel($self, $(map(x->:(rawbatch($x)), args)...)))
|
||||||
|
($self::$(esc(T)))($(args...)) = first($self(map(batchone, ($(args...),))...))
|
||||||
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
$(esc(:(Flux.update!(self::$T, η)))) = ($(map(p -> :(update!($self.$p, η)), pnames)...);)
|
||||||
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
$(esc(:(Flux.graph(self::$T)))) = $(DataFlow.constructor(mapconst(esc, makegraph(body, args))))
|
||||||
nothing
|
nothing
|
||||||
|
@ -22,12 +22,13 @@ end
|
|||||||
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
let a1 = Affine(10, 20), a2 = Affine(20, 15)
|
||||||
tlp = TLP(a1, a2)
|
tlp = TLP(a1, a2)
|
||||||
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
@test tlp(xs) ≈ softmax(a2(σ(a1(xs))))
|
||||||
|
@test Flux.interpmodel(tlp, xs) ≈ softmax(a2(σ(a1(xs))))
|
||||||
@test Flux.infer(tlp, (1, 10)) == (1,15)
|
@test Flux.infer(tlp, (1, 10)) == (1,15)
|
||||||
end
|
end
|
||||||
|
|
||||||
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
let tlp = TLP(Affine(10, 21), Affine(20, 15))
|
||||||
e = try
|
e = try
|
||||||
tlp(rand(10))
|
Flux.interpmodel(tlp, rand(10))
|
||||||
catch e
|
catch e
|
||||||
e
|
e
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user