gradients for recurrent models

This commit is contained in:
Mike J Innes 2017-04-26 17:42:47 +01:00
parent 19cf3e2b62
commit 52a7199d10
2 changed files with 14 additions and 7 deletions

View File

@ -133,7 +133,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
using Flux: Stateful, SeqModel
mxnet(m::Stateful) = Stateful(mxnet(m.model), copy(m.state))
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.istate, m.ostate)
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
# MX FeedForward interface

View File

@ -127,17 +127,25 @@ graph(cap::Capacitor) = cap.graph
# Recurrent Models
struct Stateful <: Model
mutable struct Stateful <: Model
model
state::Vector{Any}
istate::Vector{Any}
ostate::Vector{Any}
end
Stateful(model, state) = Stateful(model, state, state)
function (m::Stateful)(x)
state, y = runmodel(m.model, (m.state...,), x)
m.state .= state
m.istate = m.ostate
state, y = runmodel(m.model, (m.istate...,), x)
m.ostate = collect(state)
return y
end
function back!(m::Stateful, Δ, x)
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
end
stateless(m) = m
stateless(m::Stateful) = m.model
@ -157,5 +165,4 @@ function (m::SeqModel)(x)
end
end
(m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2)
(m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)