recurrence working
This commit is contained in:
parent
298c6f252e
commit
b4221f6ea6
@ -109,13 +109,15 @@ import Base: @get!
|
||||
# TODO: dims having its own type would be useful
|
||||
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
|
||||
|
||||
function (m::Model)(xs...)
|
||||
function Flux.runmodel(m::Model, xs...)
|
||||
!isdefined(m, :graph) &&
|
||||
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
|
||||
@mxerr m.graph.stacks runrawbatched(xs) do xs
|
||||
m.last = exec = executor(m, xs...)
|
||||
exec(xs...)
|
||||
end
|
||||
end
|
||||
|
||||
function (m::Model)(xs...)
|
||||
@mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs)
|
||||
end
|
||||
|
||||
function Flux.back!(m::Model, Δ, xs...)
|
||||
@ -127,6 +129,13 @@ end
|
||||
|
||||
Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
||||
|
||||
# Recurrent Models
|
||||
|
||||
using Flux: Stateful, SeqModel
|
||||
|
||||
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.state)
|
||||
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
|
||||
|
||||
# MX FeedForward interface
|
||||
|
||||
struct SoftmaxOutput
|
||||
|
Loading…
Reference in New Issue
Block a user