recurrence working

This commit is contained in:
Mike J Innes 2017-03-30 20:05:18 +01:00
parent 298c6f252e
commit b4221f6ea6

View File

@ -109,13 +109,15 @@ import Base: @get!
# TODO: dims having its own type would be useful
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
function (m::Model)(xs...)
function Flux.runmodel(m::Model, xs...)
!isdefined(m, :graph) &&
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
@mxerr m.graph.stacks runrawbatched(xs) do xs
m.last = exec = executor(m, xs...)
exec(xs...)
end
m.last = exec = executor(m, xs...)
exec(xs...)
end
function (m::Model)(xs...)
@mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs)
end
function Flux.back!(m::Model, Δ, xs...)
@ -127,6 +129,13 @@ end
Flux.update!(m::Model, η) = (update!(m.last, η); m)
# Recurrent Models
using Flux: Stateful, SeqModel
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.state)
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
# MX FeedForward interface
struct SoftmaxOutput