From 10abb64f4b58d993ce2b4c48254bcaeedcca8cc2 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 2 Jun 2017 16:28:31 +0100 Subject: [PATCH] fix stateful in backends --- src/backend/mxnet/model.jl | 2 +- src/backend/tensorflow/model.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index 7d3c7895..975f71ce 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -116,7 +116,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m) using Flux: Stateful, SeqModel -mxnet(m::Stateful) = Stateful(mxnet(m.model), m.istate, m.ostate) +mxnet(m::Stateful) = Stateful(mxnet(m.model), m.states, m.istate, m.ostate) mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps) # MX FeedForward interface diff --git a/src/backend/tensorflow/model.jl b/src/backend/tensorflow/model.jl index f66b2d28..836861a0 100644 --- a/src/backend/tensorflow/model.jl +++ b/src/backend/tensorflow/model.jl @@ -77,5 +77,5 @@ Flux.update!(m::Model, η) = (update!(m.exec, η); m) using Flux: Stateful, SeqModel -tf(m::Stateful) = Stateful(tf(m.model), m.istate, m.ostate) +tf(m::Stateful) = Stateful(tf(m.model), m.states, m.istate, m.ostate) tf(m::SeqModel) = SeqModel(tf(m.model), m.steps)