mimo working in mxnet
This commit is contained in:
parent
94e384930d
commit
4113d4d476
@ -106,20 +106,21 @@ mxnet(model) = Model(model)
|
|||||||
|
|
||||||
import Base: @get!
|
import Base: @get!
|
||||||
|
|
||||||
executor(m::Model, input) = @get!(m.execs, size(input), executor(m.graph, input))
|
# TODO: dims having its own type would be useful
|
||||||
|
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
|
||||||
|
|
||||||
function (m::Model)(x)
|
function (m::Model)(xs...)
|
||||||
!isdefined(m, :graph) &&
|
!isdefined(m, :graph) &&
|
||||||
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), input)))
|
(m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...))
|
||||||
@mxerr m.graph.stacks runrawbatched(x) do x
|
@mxerr m.graph.stacks runrawbatched(xs) do xs
|
||||||
m.last = exec = executor(m, x)
|
m.last = exec = executor(m, xs...)
|
||||||
exec(x)
|
exec(xs...)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function Flux.back!(m::Model, Δ, x)
|
function Flux.back!(m::Model, Δ, xs...)
|
||||||
runrawbatched(Δ, x) do Δ, x
|
runrawbatched(Δ, xs) do Δ, xs
|
||||||
m.last = exec = m.execs[size(x)]
|
m.last = exec = m.execs[mapt(size, xs)]
|
||||||
back!(exec, Δ)
|
back!(exec, Δ)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -3,15 +3,15 @@ Flux.loadmx()
|
|||||||
|
|
||||||
@testset "MXNet" begin
|
@testset "MXNet" begin
|
||||||
|
|
||||||
xs = rand(20)
|
xs, ys = rand(20), rand(20)
|
||||||
d = Affine(20, 10)
|
d = Affine(20, 10)
|
||||||
|
|
||||||
dm = mxnet(d)
|
dm = mxnet(d)
|
||||||
@test d(xs) ≈ dm(xs)
|
@test d(xs) ≈ dm(xs)
|
||||||
|
|
||||||
# m = Multi(20, 15)
|
m = Multi(20, 15)
|
||||||
# mm = mxnet(m)
|
mm = mxnet(m)
|
||||||
# @test all(isapprox.(mm(xs), m(xs)))
|
@test all(isapprox.(mm(xs, ys), m(xs, ys)))
|
||||||
|
|
||||||
@testset "Backward Pass" begin
|
@testset "Backward Pass" begin
|
||||||
d′ = deepcopy(d)
|
d′ = deepcopy(d)
|
||||||
|
Loading…
Reference in New Issue
Block a user