reorganise recurrent stuff

This commit is contained in:
Mike J Innes 2017-05-04 10:45:44 +01:00
parent c025cddc73
commit e04dcbd460
4 changed files with 51 additions and 51 deletions

View File

@ -14,26 +14,27 @@ using Juno: Tree, Row
include("utils.jl")
include("model.jl")
include("dims/catmat.jl")
include("dims/batching.jl")
include("dims/seq.jl")
include("model.jl")
include("data.jl")
include("training.jl")
include("compiler/code.jl")
include("compiler/loops.jl")
include("compiler/interp.jl")
include("compiler/shape.jl")
include("layers/control.jl")
include("layers/affine.jl")
include("layers/activation.jl")
include("layers/cost.jl")
include("layers/recurrent.jl")
include("layers/chain.jl")
include("layers/shims.jl")
include("backend/backend.jl")
include("data.jl")
include("training.jl")
end # module

View File

@ -26,3 +26,25 @@ function rebatchseq(xs)
B = Array{eltype(xs),dims+2}
Batch{Seq{T,S},B}(xs)
end
# SeqModel wrapper layer for convenience
struct SeqModel <: Model
model
steps::Int
end
runseq(f, xs::Tuple...) = f(xs...)
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
function (m::SeqModel)(x)
runseq(x) do x
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
m.model(x)
end
end
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
update!(m::SeqModel, η) = update!(m.model, η)

View File

@ -37,3 +37,26 @@ macro Chain(x, xs...)
c
end
end
# Stateful Models
mutable struct Stateful <: Model
model
istate::Vector{Any}
ostate::Vector{Any}
end
Stateful(model, state) = Stateful(model, state, state)
function (m::Stateful)(x)
m.istate = m.ostate
state, y = m.model((m.istate...,), x)
m.ostate = collect(state)
return y
end
function back!(m::Stateful, Δ, x)
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
end
update!(m::Stateful, η) = update!(m.model, η)

View File

@ -102,49 +102,3 @@ end
(m::Capacitor)(xs...) = interpmodel(m, xs...)
graph(cap::Capacitor) = cap.graph
# Recurrent Models
mutable struct Stateful <: Model
model
istate::Vector{Any}
ostate::Vector{Any}
end
Stateful(model, state) = Stateful(model, state, state)
function (m::Stateful)(x)
m.istate = m.ostate
state, y = m.model((m.istate...,), x)
m.ostate = collect(state)
return y
end
function back!(m::Stateful, Δ, x)
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
end
update!(m::Stateful, η) = update!(m.model, η)
stateless(m) = m
stateless(m::Stateful) = m.model
struct SeqModel <: Model
model
steps::Int
end
runseq(f, xs::Tuple...) = f(xs...)
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
function (m::SeqModel)(x)
runseq(x) do x
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
m.model(x)
end
end
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
update!(m::SeqModel, η) = update!(m.model, η)