move stateful
This commit is contained in:
parent
2b85c76785
commit
d1f370a2f1
@ -1,5 +1,31 @@
|
||||
export unroll, unroll1
|
||||
|
||||
# Stateful Models
|
||||
|
||||
mutable struct Stateful <: Model
|
||||
model
|
||||
states::Vector{Any}
|
||||
istate::Vector{Any}
|
||||
ostate::Vector{Any}
|
||||
end
|
||||
|
||||
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
||||
|
||||
function (m::Stateful)(x)
|
||||
m.istate = m.ostate
|
||||
state, y = m.model((m.istate...,), x)
|
||||
m.ostate = collect(state)
|
||||
return y
|
||||
end
|
||||
|
||||
function back!(m::Stateful, Δ, x)
|
||||
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
||||
end
|
||||
|
||||
update!(m::Stateful, η) = update!(m.model, η)
|
||||
|
||||
# Recurrent Graphs
|
||||
|
||||
struct Offset
|
||||
name::Symbol
|
||||
n::Int
|
||||
|
@ -37,27 +37,3 @@ macro Chain(x, xs...)
|
||||
c
|
||||
end
|
||||
end
|
||||
|
||||
# Stateful Models
|
||||
|
||||
mutable struct Stateful <: Model
|
||||
model
|
||||
states::Vector{Any}
|
||||
istate::Vector{Any}
|
||||
ostate::Vector{Any}
|
||||
end
|
||||
|
||||
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
||||
|
||||
function (m::Stateful)(x)
|
||||
m.istate = m.ostate
|
||||
state, y = m.model((m.istate...,), x)
|
||||
m.ostate = collect(state)
|
||||
return y
|
||||
end
|
||||
|
||||
function back!(m::Stateful, Δ, x)
|
||||
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
||||
end
|
||||
|
||||
update!(m::Stateful, η) = update!(m.model, η)
|
||||
|
Loading…
Reference in New Issue
Block a user