mimo working in mxnet

This commit is contained in:
Mike J Innes 2017-03-30 19:50:03 +01:00
parent 94e384930d
commit 4113d4d476
2 changed files with 14 additions and 13 deletions

View File

@ -106,20 +106,21 @@ mxnet(model) = Model(model)
import Base: @get!
executor(m::Model, input) = @get!(m.execs, size(input), executor(m.graph, input))
# TODO: dims having its own type would be useful
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
function (m::Model)(x)
function (m::Model)(xs...)
!isdefined(m, :graph) &&
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), input)))
@mxerr m.graph.stacks runrawbatched(x) do x
m.last = exec = executor(m, x)
exec(x)
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
@mxerr m.graph.stacks runrawbatched(xs) do xs
m.last = exec = executor(m, xs...)
exec(xs...)
end
end
function Flux.back!(m::Model, Δ, x)
runrawbatched(Δ, x) do Δ, x
m.last = exec = m.execs[size(x)]
function Flux.back!(m::Model, Δ, xs...)
runrawbatched(Δ, xs) do Δ, xs
m.last = exec = m.execs[mapt(size, xs)]
back!(exec, Δ)
end
end

View File

@ -3,15 +3,15 @@ Flux.loadmx()
@testset "MXNet" begin
xs = rand(20)
xs, ys = rand(20), rand(20)
d = Affine(20, 10)
dm = mxnet(d)
@test d(xs) dm(xs)
# m = Multi(20, 15)
# mm = mxnet(m)
# @test all(isapprox.(mm(xs), m(xs)))
m = Multi(20, 15)
mm = mxnet(m)
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
@testset "Backward Pass" begin
d = deepcopy(d)