recurrence overhaul mostly working
This commit is contained in:
parent
90edefe072
commit
1a5e050a88
@ -40,3 +40,5 @@ function interpmodel_(m, args...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
|
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
|
||||||
|
|
||||||
|
runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...)
|
||||||
|
@ -104,22 +104,9 @@ end
|
|||||||
|
|
||||||
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
|
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
|
||||||
|
|
||||||
# TODO: perhaps split into SeqModel + StatefulModel
|
|
||||||
type Unrolled <: Model
|
|
||||||
model
|
|
||||||
graph::IVertex{Any}
|
|
||||||
state::Vector{Any}
|
|
||||||
stateful::Bool
|
|
||||||
steps::Int
|
|
||||||
end
|
|
||||||
|
|
||||||
(m::Unrolled)(xs...) = interpret(reifyparams(m.graph), xs...)
|
|
||||||
|
|
||||||
graph(u::Unrolled) = u.graph
|
|
||||||
|
|
||||||
function unroll(model, n)
|
function unroll(model, n)
|
||||||
graph, state = unrollgraph(model, n)
|
graph, state = unrollgraph(model, n)
|
||||||
Unrolled(model, graph, state, true, n)
|
SeqModel(Stateful(Capacitor(graph), state), n)
|
||||||
end
|
end
|
||||||
|
|
||||||
function unseqin(v::IVertex)
|
function unseqin(v::IVertex)
|
||||||
@ -135,7 +122,7 @@ unseq(graph) = unseqout(unseqin(graph))
|
|||||||
|
|
||||||
function unroll1(model)
|
function unroll1(model)
|
||||||
graph, state = unrollgraph(model, 1)
|
graph, state = unrollgraph(model, 1)
|
||||||
Unrolled(model, unseq(graph), state, false, 1)
|
Stateful(Capacitor(graph), state)
|
||||||
end
|
end
|
||||||
|
|
||||||
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))
|
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))
|
||||||
|
35
src/model.jl
35
src/model.jl
@ -46,6 +46,16 @@ methods as necessary.
|
|||||||
"""
|
"""
|
||||||
graph(m) = nothing
|
graph(m) = nothing
|
||||||
|
|
||||||
|
"""
|
||||||
|
`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However,
|
||||||
|
unlike direct calling, it does not try to apply batching and simply uses the
|
||||||
|
inputs directly.
|
||||||
|
|
||||||
|
This function should be considered an implementation detail; it will be
|
||||||
|
eventually be replaced by a non-hacky way of doing batching.
|
||||||
|
"""
|
||||||
|
function runmodel end
|
||||||
|
|
||||||
# Model parameters
|
# Model parameters
|
||||||
|
|
||||||
# TODO: should be AbstractArray?
|
# TODO: should be AbstractArray?
|
||||||
@ -111,7 +121,30 @@ struct Capacitor <: Model
|
|||||||
graph::IVertex{Any}
|
graph::IVertex{Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
# TODO: batching
|
|
||||||
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
(m::Capacitor)(xs...) = interpmodel(m, xs...)
|
||||||
|
|
||||||
graph(cap::Capacitor) = cap.graph
|
graph(cap::Capacitor) = cap.graph
|
||||||
|
|
||||||
|
# Recurrent Models
|
||||||
|
|
||||||
|
mutable struct Stateful <: Model
|
||||||
|
model
|
||||||
|
state::Vector{Any}
|
||||||
|
end
|
||||||
|
|
||||||
|
function (m::Stateful)(x)
|
||||||
|
runrawbatched(x) do x
|
||||||
|
state, y = runmodel(m.model, (m.state...,), x)
|
||||||
|
m.state = collect(state)
|
||||||
|
return y
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
struct SeqModel
|
||||||
|
model
|
||||||
|
steps::Int
|
||||||
|
end
|
||||||
|
|
||||||
|
# TODO: multi input
|
||||||
|
# TODO: lift sequences
|
||||||
|
(m::SeqModel)(x) = m.model(x)
|
||||||
|
Loading…
Reference in New Issue
Block a user