runmodel no longer needed
This commit is contained in:
parent
7e5669d2f6
commit
c025cddc73
@ -39,15 +39,10 @@ function build_type(T, params)
|
|||||||
ex
|
ex
|
||||||
end
|
end
|
||||||
|
|
||||||
runmodel(f, xs...) = f(xs...)
|
|
||||||
|
|
||||||
function deref_params(v)
|
function deref_params(v)
|
||||||
v = map(v) do x
|
map(v) do x
|
||||||
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
|
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
|
||||||
end
|
end
|
||||||
prewalk(v) do v
|
|
||||||
@capture(value(v), self.p_) ? vertex(:(Flux.runmodel), constant(:(self.$p)), inputs(v)...) : v
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
function build_forward(body, args)
|
function build_forward(body, args)
|
||||||
|
12
src/model.jl
12
src/model.jl
@ -46,16 +46,6 @@ methods as necessary.
|
|||||||
"""
|
"""
|
||||||
graph(m) = nothing
|
graph(m) = nothing
|
||||||
|
|
||||||
"""
|
|
||||||
`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However,
|
|
||||||
unlike direct calling, it does not try to apply batching and simply uses the
|
|
||||||
inputs directly.
|
|
||||||
|
|
||||||
This function should be considered an implementation detail; it will be
|
|
||||||
eventually be replaced by a non-hacky way of doing batching.
|
|
||||||
"""
|
|
||||||
function runmodel end
|
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
|
|
||||||
# TODO: should be AbstractArray?
|
# TODO: should be AbstractArray?
|
||||||
@ -125,7 +115,7 @@ Stateful(model, state) = Stateful(model, state, state)
|
|||||||
|
|
||||||
function (m::Stateful)(x)
|
function (m::Stateful)(x)
|
||||||
m.istate = m.ostate
|
m.istate = m.ostate
|
||||||
state, y = runmodel(m.model, (m.istate...,), x)
|
state, y = m.model((m.istate...,), x)
|
||||||
m.ostate = collect(state)
|
m.ostate = collect(state)
|
||||||
return y
|
return y
|
||||||
end
|
end
|
||||||
|
@ -1,5 +1,7 @@
|
|||||||
export AArray, unsqueeze
|
export AArray, unsqueeze
|
||||||
|
|
||||||
|
call(f, xs...) = f(xs...)
|
||||||
|
|
||||||
# Arrays
|
# Arrays
|
||||||
|
|
||||||
const AArray = AbstractArray
|
const AArray = AbstractArray
|
||||||
|
Loading…
Reference in New Issue
Block a user