move stateful
This commit is contained in:
parent
2b85c76785
commit
d1f370a2f1
@ -1,5 +1,31 @@
|
|||||||
export unroll, unroll1
|
export unroll, unroll1
|
||||||
|
|
||||||
|
# Stateful Models
|
||||||
|
|
||||||
|
mutable struct Stateful <: Model
|
||||||
|
model
|
||||||
|
states::Vector{Any}
|
||||||
|
istate::Vector{Any}
|
||||||
|
ostate::Vector{Any}
|
||||||
|
end
|
||||||
|
|
||||||
|
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
||||||
|
|
||||||
|
function (m::Stateful)(x)
|
||||||
|
m.istate = m.ostate
|
||||||
|
state, y = m.model((m.istate...,), x)
|
||||||
|
m.ostate = collect(state)
|
||||||
|
return y
|
||||||
|
end
|
||||||
|
|
||||||
|
function back!(m::Stateful, Δ, x)
|
||||||
|
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
||||||
|
end
|
||||||
|
|
||||||
|
update!(m::Stateful, η) = update!(m.model, η)
|
||||||
|
|
||||||
|
# Recurrent Graphs
|
||||||
|
|
||||||
struct Offset
|
struct Offset
|
||||||
name::Symbol
|
name::Symbol
|
||||||
n::Int
|
n::Int
|
||||||
|
@ -37,27 +37,3 @@ macro Chain(x, xs...)
|
|||||||
c
|
c
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
# Stateful Models
|
|
||||||
|
|
||||||
mutable struct Stateful <: Model
|
|
||||||
model
|
|
||||||
states::Vector{Any}
|
|
||||||
istate::Vector{Any}
|
|
||||||
ostate::Vector{Any}
|
|
||||||
end
|
|
||||||
|
|
||||||
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
|
|
||||||
|
|
||||||
function (m::Stateful)(x)
|
|
||||||
m.istate = m.ostate
|
|
||||||
state, y = m.model((m.istate...,), x)
|
|
||||||
m.ostate = collect(state)
|
|
||||||
return y
|
|
||||||
end
|
|
||||||
|
|
||||||
function back!(m::Stateful, Δ, x)
|
|
||||||
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
|
||||||
end
|
|
||||||
|
|
||||||
update!(m::Stateful, η) = update!(m.model, η)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user