runmodel no longer needed
This commit is contained in:
parent
7e5669d2f6
commit
c025cddc73
@ -39,15 +39,10 @@ function build_type(T, params)
|
||||
ex
|
||||
end
|
||||
|
||||
runmodel(f, xs...) = f(xs...)
|
||||
|
||||
function deref_params(v)
|
||||
v = map(v) do x
|
||||
map(v) do x
|
||||
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
|
||||
end
|
||||
prewalk(v) do v
|
||||
@capture(value(v), self.p_) ? vertex(:(Flux.runmodel), constant(:(self.$p)), inputs(v)...) : v
|
||||
end
|
||||
end
|
||||
|
||||
function build_forward(body, args)
|
||||
|
12
src/model.jl
12
src/model.jl
@ -46,16 +46,6 @@ methods as necessary.
|
||||
"""
|
||||
graph(m) = nothing
|
||||
|
||||
"""
|
||||
`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However,
|
||||
unlike direct calling, it does not try to apply batching and simply uses the
|
||||
inputs directly.
|
||||
|
||||
This function should be considered an implementation detail; it will be
|
||||
eventually be replaced by a non-hacky way of doing batching.
|
||||
"""
|
||||
function runmodel end
|
||||
|
||||
# Model parameters
|
||||
|
||||
# TODO: should be AbstractArray?
|
||||
@ -125,7 +115,7 @@ Stateful(model, state) = Stateful(model, state, state)
|
||||
|
||||
function (m::Stateful)(x)
|
||||
m.istate = m.ostate
|
||||
state, y = runmodel(m.model, (m.istate...,), x)
|
||||
state, y = m.model((m.istate...,), x)
|
||||
m.ostate = collect(state)
|
||||
return y
|
||||
end
|
||||
|
@ -1,5 +1,7 @@
|
||||
export AArray, unsqueeze
|
||||
|
||||
call(f, xs...) = f(xs...)
|
||||
|
||||
# Arrays
|
||||
|
||||
const AArray = AbstractArray
|
||||
|
Loading…
Reference in New Issue
Block a user