reorganise recurrent stuff
This commit is contained in:
parent
c025cddc73
commit
e04dcbd460
11
src/Flux.jl
11
src/Flux.jl
@ -14,26 +14,27 @@ using Juno: Tree, Row
|
||||
|
||||
include("utils.jl")
|
||||
|
||||
include("model.jl")
|
||||
|
||||
include("dims/catmat.jl")
|
||||
include("dims/batching.jl")
|
||||
include("dims/seq.jl")
|
||||
|
||||
include("model.jl")
|
||||
include("data.jl")
|
||||
include("training.jl")
|
||||
|
||||
include("compiler/code.jl")
|
||||
include("compiler/loops.jl")
|
||||
include("compiler/interp.jl")
|
||||
include("compiler/shape.jl")
|
||||
|
||||
include("layers/control.jl")
|
||||
include("layers/affine.jl")
|
||||
include("layers/activation.jl")
|
||||
include("layers/cost.jl")
|
||||
include("layers/recurrent.jl")
|
||||
include("layers/chain.jl")
|
||||
include("layers/shims.jl")
|
||||
|
||||
include("backend/backend.jl")
|
||||
|
||||
include("data.jl")
|
||||
include("training.jl")
|
||||
|
||||
end # module
|
||||
|
@ -26,3 +26,25 @@ function rebatchseq(xs)
|
||||
B = Array{eltype(xs),dims+2}
|
||||
Batch{Seq{T,S},B}(xs)
|
||||
end
|
||||
|
||||
# SeqModel wrapper layer for convenience
|
||||
|
||||
struct SeqModel <: Model
|
||||
model
|
||||
steps::Int
|
||||
end
|
||||
|
||||
runseq(f, xs::Tuple...) = f(xs...)
|
||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||
|
||||
function (m::SeqModel)(x)
|
||||
runseq(x) do x
|
||||
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||
m.model(x)
|
||||
end
|
||||
end
|
||||
|
||||
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||
|
||||
update!(m::SeqModel, η) = update!(m.model, η)
|
||||
|
@ -37,3 +37,26 @@ macro Chain(x, xs...)
|
||||
c
|
||||
end
|
||||
end
|
||||
|
||||
# Stateful Models
|
||||
|
||||
mutable struct Stateful <: Model
|
||||
model
|
||||
istate::Vector{Any}
|
||||
ostate::Vector{Any}
|
||||
end
|
||||
|
||||
Stateful(model, state) = Stateful(model, state, state)
|
||||
|
||||
function (m::Stateful)(x)
|
||||
m.istate = m.ostate
|
||||
state, y = m.model((m.istate...,), x)
|
||||
m.ostate = collect(state)
|
||||
return y
|
||||
end
|
||||
|
||||
function back!(m::Stateful, Δ, x)
|
||||
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
||||
end
|
||||
|
||||
update!(m::Stateful, η) = update!(m.model, η)
|
46
src/model.jl
46
src/model.jl
@ -102,49 +102,3 @@ end
|
||||
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
||||
|
||||
graph(cap::Capacitor) = cap.graph
|
||||
|
||||
# Recurrent Models
|
||||
|
||||
mutable struct Stateful <: Model
|
||||
model
|
||||
istate::Vector{Any}
|
||||
ostate::Vector{Any}
|
||||
end
|
||||
|
||||
Stateful(model, state) = Stateful(model, state, state)
|
||||
|
||||
function (m::Stateful)(x)
|
||||
m.istate = m.ostate
|
||||
state, y = m.model((m.istate...,), x)
|
||||
m.ostate = collect(state)
|
||||
return y
|
||||
end
|
||||
|
||||
function back!(m::Stateful, Δ, x)
|
||||
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
||||
end
|
||||
|
||||
update!(m::Stateful, η) = update!(m.model, η)
|
||||
|
||||
stateless(m) = m
|
||||
stateless(m::Stateful) = m.model
|
||||
|
||||
struct SeqModel <: Model
|
||||
model
|
||||
steps::Int
|
||||
end
|
||||
|
||||
runseq(f, xs::Tuple...) = f(xs...)
|
||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||
|
||||
function (m::SeqModel)(x)
|
||||
runseq(x) do x
|
||||
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||
m.model(x)
|
||||
end
|
||||
end
|
||||
|
||||
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||
|
||||
update!(m::SeqModel, η) = update!(m.model, η)
|
||||
|
Loading…
Reference in New Issue
Block a user