runmodel no longer needed

This commit is contained in:
Mike J Innes 2017-05-04 10:32:53 +01:00
parent 7e5669d2f6
commit c025cddc73
3 changed files with 4 additions and 17 deletions

View File

@ -39,15 +39,10 @@ function build_type(T, params)
ex ex
end end
runmodel(f, xs...) = f(xs...)
function deref_params(v) function deref_params(v)
v = map(v) do x map(v) do x
x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x x isa Constant && @capture(x.value, self.p_) ? Constant(:(Flux.state(self.$p))) : x
end end
prewalk(v) do v
@capture(value(v), self.p_) ? vertex(:(Flux.runmodel), constant(:(self.$p)), inputs(v)...) : v
end
end end
function build_forward(body, args) function build_forward(body, args)

View File

@ -46,16 +46,6 @@ methods as necessary.
""" """
graph(m) = nothing graph(m) = nothing
"""
`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However,
unlike direct calling, it does not try to apply batching and simply uses the
inputs directly.
This function should be considered an implementation detail; it will be
eventually be replaced by a non-hacky way of doing batching.
"""
function runmodel end
# Model parameters # Model parameters
# TODO: should be AbstractArray? # TODO: should be AbstractArray?
@ -125,7 +115,7 @@ Stateful(model, state) = Stateful(model, state, state)
function (m::Stateful)(x) function (m::Stateful)(x)
m.istate = m.ostate m.istate = m.ostate
state, y = runmodel(m.model, (m.istate...,), x) state, y = m.model((m.istate...,), x)
m.ostate = collect(state) m.ostate = collect(state)
return y return y
end end

View File

@ -1,5 +1,7 @@
export AArray, unsqueeze export AArray, unsqueeze
call(f, xs...) = f(xs...)
# Arrays # Arrays
const AArray = AbstractArray const AArray = AbstractArray