recurrence working
This commit is contained in:
parent
298c6f252e
commit
b4221f6ea6
@ -109,13 +109,15 @@ import Base: @get!
|
|||||||
# TODO: dims having its own type would be useful
|
# TODO: dims having its own type would be useful
|
||||||
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
|
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
|
||||||
|
|
||||||
function (m::Model)(xs...)
|
function Flux.runmodel(m::Model, xs...)
|
||||||
!isdefined(m, :graph) &&
|
!isdefined(m, :graph) &&
|
||||||
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
|
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
|
||||||
@mxerr m.graph.stacks runrawbatched(xs) do xs
|
m.last = exec = executor(m, xs...)
|
||||||
m.last = exec = executor(m, xs...)
|
exec(xs...)
|
||||||
exec(xs...)
|
end
|
||||||
end
|
|
||||||
|
function (m::Model)(xs...)
|
||||||
|
@mxerr m.graph.stacks runrawbatched(xs -> Flux.runmodel(m, xs...), xs)
|
||||||
end
|
end
|
||||||
|
|
||||||
function Flux.back!(m::Model, Δ, xs...)
|
function Flux.back!(m::Model, Δ, xs...)
|
||||||
@ -127,6 +129,13 @@ end
|
|||||||
|
|
||||||
Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
Flux.update!(m::Model, η) = (update!(m.last, η); m)
|
||||||
|
|
||||||
|
# Recurrent Models
|
||||||
|
|
||||||
|
using Flux: Stateful, SeqModel
|
||||||
|
|
||||||
|
mxnet(m::Stateful) = Stateful(mxnet(m.model), m.state)
|
||||||
|
mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
|
||||||
|
|
||||||
# MX FeedForward interface
|
# MX FeedForward interface
|
||||||
|
|
||||||
struct SoftmaxOutput
|
struct SoftmaxOutput
|
||||||
|
Loading…
Reference in New Issue
Block a user