gradients for recurrent models
This commit is contained in:
parent
19cf3e2b62
commit
52a7199d10
@ -133,7 +133,7 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
|||||||
|
|
||||||
using Flux: Stateful, SeqModel
|
using Flux: Stateful, SeqModel
|
||||||
|
|
||||||
mxnet(m::Stateful) = Stateful(mxnet(m.model), copy(m.state))
|
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.istate, m.ostate)
|
||||||
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
|
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
|
||||||
|
|
||||||
# MX FeedForward interface
|
# MX FeedForward interface
|
||||||
|
19
src/model.jl
19
src/model.jl
@ -127,17 +127,25 @@ graph(cap::Capacitor) = cap.graph
|
|||||||
|
|
||||||
# Recurrent Models
|
# Recurrent Models
|
||||||
|
|
||||||
struct Stateful <: Model
|
mutable struct Stateful <: Model
|
||||||
model
|
model
|
||||||
state::Vector{Any}
|
istate::Vector{Any}
|
||||||
|
ostate::Vector{Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Stateful(model, state) = Stateful(model, state, state)
|
||||||
|
|
||||||
function (m::Stateful)(x)
|
function (m::Stateful)(x)
|
||||||
state, y = runmodel(m.model, (m.state...,), x)
|
m.istate = m.ostate
|
||||||
m.state .= state
|
state, y = runmodel(m.model, (m.istate...,), x)
|
||||||
|
m.ostate = collect(state)
|
||||||
return y
|
return y
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function back!(m::Stateful, Δ, x)
|
||||||
|
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
||||||
|
end
|
||||||
|
|
||||||
stateless(m) = m
|
stateless(m) = m
|
||||||
stateless(m::Stateful) = m.model
|
stateless(m::Stateful) = m.model
|
||||||
|
|
||||||
@ -157,5 +165,4 @@ function (m::SeqModel)(x)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::SeqModel)(x::AbstractArray) = stack(m((unstack(x, 2)...,)), 2)
|
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||||
(m::SeqModel)(x::BatchSeq) = rebatchseq(m(rawbatch(x)))
|
|
||||||
|
Loading…
Reference in New Issue
Block a user