diff --git a/src/compiler/interp.jl b/src/compiler/interp.jl index f8d3b883..bd5803fd 100644 --- a/src/compiler/interp.jl +++ b/src/compiler/interp.jl @@ -40,3 +40,5 @@ function interpmodel_(m, args...) end interpmodel(m, args...) = @ithrow runrawbatched((xs...) -> interpmodel_(m, xs...), args...) + +runmodel(m::Capacitor, xs...) = @ithrow interpmodel_(m, xs...) diff --git a/src/compiler/loops.jl b/src/compiler/loops.jl index cc09b28b..1f4a1e0d 100644 --- a/src/compiler/loops.jl +++ b/src/compiler/loops.jl @@ -104,22 +104,9 @@ end unrollgraph(m, n; kws...) = unrollgraph(atomise(m), n; kws...) -# TODO: perhaps split into SeqModel + StatefulModel -type Unrolled <: Model - model - graph::IVertex{Any} - state::Vector{Any} - stateful::Bool - steps::Int -end - -(m::Unrolled)(xs...) = interpret(reifyparams(m.graph), xs...) - -graph(u::Unrolled) = u.graph - function unroll(model, n) graph, state = unrollgraph(model, n) - Unrolled(model, graph, state, true, n) + SeqModel(Stateful(Capacitor(graph), state), n) end function unseqin(v::IVertex) @@ -135,7 +122,7 @@ unseq(graph) = unseqout(unseqin(graph)) function unroll1(model) graph, state = unrollgraph(model, 1) - Unrolled(model, unseq(graph), state, false, 1) + Stateful(Capacitor(graph), state) end flip(model) = Capacitor(map(x -> x isa Offset ? -x : x, atomise(model))) diff --git a/src/model.jl b/src/model.jl index b74700d0..9a4c199a 100644 --- a/src/model.jl +++ b/src/model.jl @@ -46,6 +46,16 @@ methods as necessary. """ graph(m) = nothing +""" +`runmodel(m, ...)` is like `m(...)`, i.e. it runs the forward pass. However, +unlike direct calling, it does not try to apply batching and simply uses the +inputs directly. + +This function should be considered an implementation detail; it will be +eventually be replaced by a non-hacky way of doing batching. +""" +function runmodel end + # Model parameters # TODO: should be AbstractArray? @@ -111,7 +121,30 @@ struct Capacitor <: Model graph::IVertex{Any} end -# TODO: batching (m::Capacitor)(xs...) = interpmodel(m, xs...) graph(cap::Capacitor) = cap.graph + +# Recurrent Models + +mutable struct Stateful <: Model + model + state::Vector{Any} +end + +function (m::Stateful)(x) + runrawbatched(x) do x + state, y = runmodel(m.model, (m.state...,), x) + m.state = collect(state) + return y + end +end + +struct SeqModel + model + steps::Int +end + +# TODO: multi input +# TODO: lift sequences +(m::SeqModel)(x) = m.model(x)