From d1f370a2f1d121d553c4a242b47c4b2fc64c6c59 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 2 Jun 2017 16:02:47 +0100 Subject: [PATCH] move stateful --- src/compiler/loops.jl | 26 ++++++++++++++++++++++++++ src/layers/control.jl | 24 ------------------------ 2 files changed, 26 insertions(+), 24 deletions(-) diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index 3785d1a2..6b38bf13 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -1,5 +1,31 @@ export unroll, unroll1 +# Stateful Models + +mutable struct Stateful <: Model + model + states::Vector{Any} + istate::Vector{Any} + ostate::Vector{Any} +end + +Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss)) + +function (m::Stateful)(x) + m.istate = m.ostate + state, y = m.model((m.istate...,), x) + m.ostate = collect(state) + return y +end + +function back!(m::Stateful, Δ, x) + back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end] +end + +update!(m::Stateful, η) = update!(m.model, η) + +# Recurrent Graphs + struct Offset name::Symbol n::Int diff --git a/src/layers/control.jl b/src/layers/control.jl index 7cc907f7..839bb58c 100644 --- a/src/layers/control.jl +++ b/src/layers/control.jl @@ -37,27 +37,3 @@ macro Chain(x, xs...) c end end - -# Stateful Models - -mutable struct Stateful <: Model - model - states::Vector{Any} - istate::Vector{Any} - ostate::Vector{Any} -end - -Stateful(model, ss) = Stateful(model, ss, state.(ss), state.(ss)) - -function (m::Stateful)(x) - m.istate = m.ostate - state, y = m.model((m.istate...,), x) - m.ostate = collect(state) - return y -end - -function back!(m::Stateful, Δ, x) - back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end] -end - -update!(m::Stateful, η) = update!(m.model, η)