recurrence overhaul mostly working

This commit is contained in:
Mike J Innes 2017-03-21 01:32:12 +00:00
parent 90edefe072
commit 1a5e050a88
3 changed files with 38 additions and 16 deletions

View File

@ -40,3 +40,5 @@ function interpmodel_(m, args...)
end
interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...)
runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...)

View File

@ -104,22 +104,9 @@ end
unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...)
# TODO: perhaps split into SeqModel + StatefulModel
type Unrolled <: Model
model
graph::IVertex{Any}
state::Vector{Any}
stateful::Bool
steps::Int
end
(m::Unrolled)(xs...) = interpret(reifyparams(m.graph), xs...)
graph(u::Unrolled) = u.graph
function unroll(model, n)
graph, state = unrollgraph(model, n)
Unrolled(model, graph, state, true, n)
SeqModel(Stateful(Capacitor(graph), state), n)
end
function unseqin(v::IVertex)
@ -135,7 +122,7 @@ unseq(graph) = unseqout(unseqin(graph))
function unroll1(model)
graph, state = unrollgraph(model, 1)
Unrolled(model, unseq(graph), state, false, 1)
Stateful(Capacitor(graph), state)
end
flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model)))

View File

@ -46,6 +46,16 @@ methods as necessary.
"""
graph(m) = nothing
"""
`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However,
unlike direct calling, it does not try to apply batching and simply uses the
inputs directly.
This function should be considered an implementation detail; it will be
eventually be replaced by a non-hacky way of doing batching.
"""
function runmodel end
# Model parameters
# TODO: should be AbstractArray?
@ -111,7 +121,30 @@ struct Capacitor <: Model
graph::IVertex{Any}
end
# TODO: batching
(m::Capacitor)(xs...) = interpmodel(m, xs...)
graph(cap::Capacitor) = cap.graph
# Recurrent Models
mutable struct Stateful <: Model
model
state::Vector{Any}
end
function (m::Stateful)(x)
runrawbatched(x) do x
state, y = runmodel(m.model, (m.state...,), x)
m.state = collect(state)
return y
end
end
struct SeqModel
model
steps::Int
end
# TODO: multi input
# TODO: lift sequences
(m::SeqModel)(x) = m.model(x)