mimo working in mxnet
This commit is contained in:
parent
94e384930d
commit
4113d4d476
|
@ -106,20 +106,21 @@ mxnet(model) = Model(model)
|
|||
|
||||
import Base: @get!
|
||||
|
||||
executor(m::Model, input) = @get!(m.execs, size(input), executor(m.graph, input))
|
||||
# TODO: dims having its own type would be useful
|
||||
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
|
||||
|
||||
function (m::Model)(x)
|
||||
function (m::Model)(xs...)
|
||||
!isdefined(m, :graph) &&
|
||||
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), input)))
|
||||
@mxerr m.graph.stacks runrawbatched(x) do x
|
||||
m.last = exec = executor(m, x)
|
||||
exec(x)
|
||||
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
|
||||
@mxerr m.graph.stacks runrawbatched(xs) do xs
|
||||
m.last = exec = executor(m, xs...)
|
||||
exec(xs...)
|
||||
end
|
||||
end
|
||||
|
||||
function Flux.back!(m::Model, Δ, x)
|
||||
runrawbatched(Δ, x) do Δ, x
|
||||
m.last = exec = m.execs[size(x)]
|
||||
function Flux.back!(m::Model, Δ, xs...)
|
||||
runrawbatched(Δ, xs) do Δ, xs
|
||||
m.last = exec = m.execs[mapt(size, xs)]
|
||||
back!(exec, Δ)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,15 +3,15 @@ Flux.loadmx()
|
|||
|
||||
@testset "MXNet" begin
|
||||
|
||||
xs = rand(20)
|
||||
xs, ys = rand(20), rand(20)
|
||||
d = Affine(20, 10)
|
||||
|
||||
dm = mxnet(d)
|
||||
@test d(xs) ≈ dm(xs)
|
||||
|
||||
# m = Multi(20, 15)
|
||||
# mm = mxnet(m)
|
||||
# @test all(isapprox.(mm(xs), m(xs)))
|
||||
m = Multi(20, 15)
|
||||
mm = mxnet(m)
|
||||
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
|
||||
|
||||
@testset "Backward Pass" begin
|
||||
d′ = deepcopy(d)
|
||||
|
|
Loading…
Reference in New Issue