move stateful

This commit is contained in:
Mike J Innes 2017-06-02 16:02:47 +01:00
parent 2b85c76785
commit d1f370a2f1
2 changed files with 26 additions and 24 deletions

View File

@ -1,5 +1,31 @@
export unroll, unroll1
# Stateful Models
mutable struct Stateful <: Model
model
states::Vector{Any}
istate::Vector{Any}
ostate::Vector{Any}
end
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
function (m::Stateful)(x)
m.istate = m.ostate
state, y = m.model((m.istate...,), x)
m.ostate = collect(state)
return y
end
function back!(m::Stateful, Δ, x)
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
end
update!(m::Stateful, η) = update!(m.model, η)
# Recurrent Graphs
struct Offset
name::Symbol
n::Int

View File

@ -37,27 +37,3 @@ macro Chain(x, xs...)
c
end
end
# Stateful Models
mutable struct Stateful <: Model
model
states::Vector{Any}
istate::Vector{Any}
ostate::Vector{Any}
end
Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss))
function (m::Stateful)(x)
m.istate = m.ostate
state, y = m.model((m.istate...,), x)
m.ostate = collect(state)
return y
end
function back!(m::Stateful, Δ, x)
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
end
update!(m::Stateful, η) = update!(m.model, η)