From b4221f6ea62ca989643a68ed7d239dc67a625843 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 30 Mar 2017 20:05:18 +0100 Subject: [PATCH] recurrence working --- src/backend/mxnet/model.jl | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 5e0576e4..9d7a310e 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -109,13 +109,15 @@ import Base: @get! # TODO: dims having its own type would be useful executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...)) -function (m::Model)(xs...) +function Flux.runmodel(m::Model, xs...) !isdefined(m, :graph) && (m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...)) - @mxerr m.graph.stacks runrawbatched(xs) do xs - m.last = exec = executor(m, xs...) - exec(xs...) - end + m.last = exec = executor(m, xs...) + exec(xs...) +end + +function (m::Model)(xs...) + @mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs) end function Flux.back!(m::Model, Δ, xs...) @@ -127,6 +129,13 @@ end Flux.update!(m::Model, η) = (update!(m.last, η); m) +# Recurrent Models + +using Flux: Stateful, SeqModel + +mxnet(m::Stateful) = Stateful(mxnet(m.model), m.state) +mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps) + # MX FeedForward interface struct SoftmaxOutput