gradients for recurrent models
This commit is contained in:
parent
19cf3e2b62
commit
52a7199d10
@ -133,7 +133,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
||||
|
||||
using Flux: Stateful, SeqModel
|
||||
|
||||
mxnet(m::Stateful) = Stateful(mxnet(m.model), copy(m.state))
|
||||
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.istate, m.ostate)
|
||||
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
|
||||
|
||||
# MX FeedForward interface
|
||||
|
19
src/model.jl
19
src/model.jl
@ -127,17 +127,25 @@ graph(cap::Capacitor) = cap.graph
|
||||
|
||||
# Recurrent Models
|
||||
|
||||
struct Stateful <: Model
|
||||
mutable struct Stateful <: Model
|
||||
model
|
||||
state::Vector{Any}
|
||||
istate::Vector{Any}
|
||||
ostate::Vector{Any}
|
||||
end
|
||||
|
||||
Stateful(model, state) = Stateful(model, state, state)
|
||||
|
||||
function (m::Stateful)(x)
|
||||
state, y = runmodel(m.model, (m.state...,), x)
|
||||
m.state .= state
|
||||
m.istate = m.ostate
|
||||
state, y = runmodel(m.model, (m.istate...,), x)
|
||||
m.ostate = collect(state)
|
||||
return y
|
||||
end
|
||||
|
||||
function back!(m::Stateful, Δ, x)
|
||||
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
||||
end
|
||||
|
||||
stateless(m) = m
|
||||
stateless(m::Stateful) = m.model
|
||||
|
||||
@ -157,5 +165,4 @@ function (m::SeqModel)(x)
|
||||
end
|
||||
end
|
||||
|
||||
(m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2)
|
||||
(m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))
|
||||
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||
|
Loading…
Reference in New Issue
Block a user