reorganise recurrent stuff
This commit is contained in:
parent
c025cddc73
commit
e04dcbd460
11
src/Flux.jl
11
src/Flux.jl
@ -14,26 +14,27 @@ using Juno: Tree, Row
|
|||||||
|
|
||||||
include("utils.jl")
|
include("utils.jl")
|
||||||
|
|
||||||
|
include("model.jl")
|
||||||
|
|
||||||
include("dims/catmat.jl")
|
include("dims/catmat.jl")
|
||||||
include("dims/batching.jl")
|
include("dims/batching.jl")
|
||||||
include("dims/seq.jl")
|
include("dims/seq.jl")
|
||||||
|
|
||||||
include("model.jl")
|
|
||||||
include("data.jl")
|
|
||||||
include("training.jl")
|
|
||||||
|
|
||||||
include("compiler/code.jl")
|
include("compiler/code.jl")
|
||||||
include("compiler/loops.jl")
|
include("compiler/loops.jl")
|
||||||
include("compiler/interp.jl")
|
include("compiler/interp.jl")
|
||||||
include("compiler/shape.jl")
|
include("compiler/shape.jl")
|
||||||
|
|
||||||
|
include("layers/control.jl")
|
||||||
include("layers/affine.jl")
|
include("layers/affine.jl")
|
||||||
include("layers/activation.jl")
|
include("layers/activation.jl")
|
||||||
include("layers/cost.jl")
|
include("layers/cost.jl")
|
||||||
include("layers/recurrent.jl")
|
include("layers/recurrent.jl")
|
||||||
include("layers/chain.jl")
|
|
||||||
include("layers/shims.jl")
|
include("layers/shims.jl")
|
||||||
|
|
||||||
include("backend/backend.jl")
|
include("backend/backend.jl")
|
||||||
|
|
||||||
|
include("data.jl")
|
||||||
|
include("training.jl")
|
||||||
|
|
||||||
end # module
|
end # module
|
||||||
|
@ -26,3 +26,25 @@ function rebatchseq(xs)
|
|||||||
B = Array{eltype(xs),dims+2}
|
B = Array{eltype(xs),dims+2}
|
||||||
Batch{Seq{T,S},B}(xs)
|
Batch{Seq{T,S},B}(xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# SeqModel wrapper layer for convenience
|
||||||
|
|
||||||
|
struct SeqModel <: Model
|
||||||
|
model
|
||||||
|
steps::Int
|
||||||
|
end
|
||||||
|
|
||||||
|
runseq(f, xs::Tuple...) = f(xs...)
|
||||||
|
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
||||||
|
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
||||||
|
|
||||||
|
function (m::SeqModel)(x)
|
||||||
|
runseq(x) do x
|
||||||
|
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
||||||
|
m.model(x)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
||||||
|
|
||||||
|
update!(m::SeqModel, η) = update!(m.model, η)
|
||||||
|
@ -37,3 +37,26 @@ macro Chain(x, xs...)
|
|||||||
c
|
c
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Stateful Models
|
||||||
|
|
||||||
|
mutable struct Stateful <: Model
|
||||||
|
model
|
||||||
|
istate::Vector{Any}
|
||||||
|
ostate::Vector{Any}
|
||||||
|
end
|
||||||
|
|
||||||
|
Stateful(model, state) = Stateful(model, state, state)
|
||||||
|
|
||||||
|
function (m::Stateful)(x)
|
||||||
|
m.istate = m.ostate
|
||||||
|
state, y = m.model((m.istate...,), x)
|
||||||
|
m.ostate = collect(state)
|
||||||
|
return y
|
||||||
|
end
|
||||||
|
|
||||||
|
function back!(m::Stateful, Δ, x)
|
||||||
|
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
||||||
|
end
|
||||||
|
|
||||||
|
update!(m::Stateful, η) = update!(m.model, η)
|
46
src/model.jl
46
src/model.jl
@ -102,49 +102,3 @@ end
|
|||||||
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
||||||
|
|
||||||
graph(cap::Capacitor) = cap.graph
|
graph(cap::Capacitor) = cap.graph
|
||||||
|
|
||||||
# Recurrent Models
|
|
||||||
|
|
||||||
mutable struct Stateful <: Model
|
|
||||||
model
|
|
||||||
istate::Vector{Any}
|
|
||||||
ostate::Vector{Any}
|
|
||||||
end
|
|
||||||
|
|
||||||
Stateful(model, state) = Stateful(model, state, state)
|
|
||||||
|
|
||||||
function (m::Stateful)(x)
|
|
||||||
m.istate = m.ostate
|
|
||||||
state, y = m.model((m.istate...,), x)
|
|
||||||
m.ostate = collect(state)
|
|
||||||
return y
|
|
||||||
end
|
|
||||||
|
|
||||||
function back!(m::Stateful, Δ, x)
|
|
||||||
back!(m.model, ((zeros.(m.ostate)...,), Δ), (m.istate...,), x)[2:end]
|
|
||||||
end
|
|
||||||
|
|
||||||
update!(m::Stateful, η) = update!(m.model, η)
|
|
||||||
|
|
||||||
stateless(m) = m
|
|
||||||
stateless(m::Stateful) = m.model
|
|
||||||
|
|
||||||
struct SeqModel <: Model
|
|
||||||
model
|
|
||||||
steps::Int
|
|
||||||
end
|
|
||||||
|
|
||||||
runseq(f, xs::Tuple...) = f(xs...)
|
|
||||||
runseq(f, xs::AbstractArray...) = stack(f(map(x -> (unstack(x,2)...,), xs)...), 2)
|
|
||||||
runseq(f, xs::BatchSeq...) = rebatchseq(runseq(f, rawbatch.(xs)...))
|
|
||||||
|
|
||||||
function (m::SeqModel)(x)
|
|
||||||
runseq(x) do x
|
|
||||||
@assert length(x) == m.steps "Expected seq length $(m.steps), got $(size(x, 2))"
|
|
||||||
m.model(x)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
back!(m::SeqModel, Δ, x) = (runseq((Δ, x) -> back!(m.model, Δ, x)[1], Δ, x),)
|
|
||||||
|
|
||||||
update!(m::SeqModel, η) = update!(m.model, η)
|
|
||||||
|
Loading…
Reference in New Issue
Block a user