recurrence overhaul mostly working
This commit is contained in:
parent
90edefe072
commit
1a5e050a88
@ -40,3 +40,5 @@ function interpmodel_(m, args...)
|
||||
end
|
||||
|
||||
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
|
||||
|
||||
runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...)
|
||||
|
@ -104,22 +104,9 @@ end
|
||||
|
||||
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
|
||||
|
||||
# TODO: perhaps split into SeqModel + StatefulModel
|
||||
type Unrolled <: Model
|
||||
model
|
||||
graph::IVertex{Any}
|
||||
state::Vector{Any}
|
||||
stateful::Bool
|
||||
steps::Int
|
||||
end
|
||||
|
||||
(m::Unrolled)(xs...) = interpret(reifyparams(m.graph), xs...)
|
||||
|
||||
graph(u::Unrolled) = u.graph
|
||||
|
||||
function unroll(model, n)
|
||||
graph, state = unrollgraph(model, n)
|
||||
Unrolled(model, graph, state, true, n)
|
||||
SeqModel(Stateful(Capacitor(graph), state), n)
|
||||
end
|
||||
|
||||
function unseqin(v::IVertex)
|
||||
@ -135,7 +122,7 @@ unseq(graph) = unseqout(unseqin(graph))
|
||||
|
||||
function unroll1(model)
|
||||
graph, state = unrollgraph(model, 1)
|
||||
Unrolled(model, unseq(graph), state, false, 1)
|
||||
Stateful(Capacitor(graph), state)
|
||||
end
|
||||
|
||||
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))
|
||||
|
35
src/model.jl
35
src/model.jl
@ -46,6 +46,16 @@ methods as necessary.
|
||||
"""
|
||||
graph(m) = nothing
|
||||
|
||||
"""
|
||||
`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However,
|
||||
unlike direct calling, it does not try to apply batching and simply uses the
|
||||
inputs directly.
|
||||
|
||||
This function should be considered an implementation detail; it will be
|
||||
eventually be replaced by a non-hacky way of doing batching.
|
||||
"""
|
||||
function runmodel end
|
||||
|
||||
# Model parameters
|
||||
|
||||
# TODO: should be AbstractArray?
|
||||
@ -111,7 +121,30 @@ struct Capacitor <: Model
|
||||
graph::IVertex{Any}
|
||||
end
|
||||
|
||||
# TODO: batching
|
||||
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
||||
|
||||
graph(cap::Capacitor) = cap.graph
|
||||
|
||||
# Recurrent Models
|
||||
|
||||
mutable struct Stateful <: Model
|
||||
model
|
||||
state::Vector{Any}
|
||||
end
|
||||
|
||||
function (m::Stateful)(x)
|
||||
runrawbatched(x) do x
|
||||
state, y = runmodel(m.model, (m.state...,), x)
|
||||
m.state = collect(state)
|
||||
return y
|
||||
end
|
||||
end
|
||||
|
||||
struct SeqModel
|
||||
model
|
||||
steps::Int
|
||||
end
|
||||
|
||||
# TODO: multi input
|
||||
# TODO: lift sequences
|
||||
(m::SeqModel)(x) = m.model(x)
|
||||
|
Loading…
Reference in New Issue
Block a user