nicer mxnet api

Mike J Innes 2017-03-08 21:41:13 +00:00
parent 9c9feb9ba0
commit 854a1e1886
5 changed files with 82 additions and 53 deletions

@@ -18,7 +18,7 @@ directly won't have great performance. In order to run a computationally intensi
 This is easy to do. Just call either `mxnet` or `tf` on a model to convert it to a model of that kind:
 ```julia
-mxmodel = mxnet(model, (10, 1))
+mxmodel = mxnet(model)
 mxmodel(xs) #> [0.0650, 0.0655, ...]
 # or
 tfmodel = tf(model)
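The shape argument disappears here because the backend no longer binds an MXNet executor at conversion time; the new `Model` type later in this commit builds one lazily, keyed on the input size of the first call. A minimal, hedged sketch of the resulting call pattern (assuming `model` takes length-10 inputs):

```julia
mxmodel = mxnet(model)   # no input shape required up front
mxmodel(rand(10))        # the first call binds an executor for this input size
mxmodel(rand(10))        # later calls of the same size reuse the cached executor
```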

@@ -14,18 +14,20 @@ end
 model = TLP(Affine(10, 20), Affine(21, 15))
-mxmodel = mxnet(model, (10, 1))
+mxmodel = mxnet(model)
+mxmodel(rand(10))
 ```
 Unfortunately, this model has a (fairly obvious) typo, which means that the code above won't run. Instead we get an error message:
 ```julia
-InferShape Error in dot5: [20:37:39] src/operator/./matrix_op-inl.h:271:
-Check failed: (lshape[1]) == (rshape[0]) dot shape error: (15,21) X (20,1)
- in Flux.Affine at affine.jl:8
- in TLP at test.jl:6
- in mxnet(::TLP, ::Tuple{Int64,Int64}) at model.jl:40
- in mxnet(::TLP, ::Vararg{Any,N} where N) at backend.jl:20
+Error in operator dot2: [21:28:21] src/operator/tensor/./matrix_op-inl.h:460:
+Check failed: lshape[1] == rshape[0] (20 vs. 21) dot shape error: (1,20) X (21,15)
+Flux.Affine at affine.jl:8
+TLP at basic.jl:6
+(::Flux.MX.Model)(::Flux.Batch{Array{Float64,1},Array{Float64,2}}) at model.jl:105
+(::Flux.MX.Model)(::Array{Float64,1}) at model.jl:107
 ```
 Most frameworks would only give the error message here, which is not so helpful if you have thousands of nodes in your computational graph. However, Flux is able to give good error reports *even when no Julia code has been run*, e.g. when running on a backend like MXNet. This enables us to pinpoint the source of the error very quickly even in a large model.
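The typo the trace points at is the mismatched inner dimension: the first layer produces 20 outputs while the second expects 21. A hedged sketch of the corrected model (the fix is left implicit in the docs themselves):

```julia
model = TLP(Affine(10, 20), Affine(20, 15))  # second layer now matches the 20 outputs of the first
mxmodel = mxnet(model)
mxmodel(rand(10))  # runs without the shape error
```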

@@ -15,7 +15,7 @@ function loadmx()
   @eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
 end
-function mxnet(args...)
+function mxnet(m)
   loadmx()
-  eval(:(MX.mxnet($(args...))))
+  eval(:(MX.mxnet($m)))
 end

@@ -8,6 +8,12 @@ end
 Base.size(p::AlterParam) = size(p.load(p.param.x))
+function copyargs!(as, bs)
+  for id in intersect(keys(as), keys(bs))
+    copy!(as[id], bs[id])
+  end
+end
 type Graph
   output
   params::Dict{Symbol,Any}
@@ -22,75 +28,95 @@ function mxparams(g::Graph)
   return params
 end
-function copyargs!(as, bs)
-  for id in intersect(keys(as), keys(bs))
-    copy!(as[id], bs[id])
-  end
-end
 ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
-type Model <: Flux.Model
-  model::Any
+type Exec <: Flux.Model
   graph::Graph
+  exec::mx.Executor
   args::Dict{Symbol,MXArray}
   grads::Dict{Symbol,MXArray}
   outs::Vector{MXArray}
-  exec::mx.Executor
 end
-loadparams!(model::Model) = copyargs!(model.args, model.graph.params)
-storeparams!(model::Model) = copyargs!(model.graph.params, model.args)
+loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params)
+storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args)
 mxgroup(x) = x
 mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
 mxungroup(x, outs) = copy(shift!(outs))
 mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
-function mxnet(model::Flux.Model, input)
-  graph = tograph(model, mx.Variable(:input))
+function executor(graph::Graph, input)
   args = merge(mxparams(graph), Dict(:input => MXArray(input)))
   grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
-  exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
+  exec = mx.bind(mxgroup(graph.output),
                 args = ndparams(args),
                 args_grad = ndparams(grads),
                 grad_req = mx.GRAD_ADD)
-  model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
-  loadparams!(model)
-  return model
+  exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
+  loadparams!(exec)
+  return exec
 end
-function runmodel(model::Model, input)
-  copy!(model.args[:input], input)
-  mx.forward(model.exec, is_train = true)
-  mxungroup(model.graph.output, copy(model.outs))
+function (exec::Exec)(input)
+  copy!(exec.args[:input], input)
+  mx.forward(exec.exec, is_train = true)
+  mxungroup(exec.graph.output, copy(exec.outs))
 end
-(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
-(m::Model)(x) = unbatchone(m(batchone(x)))
-function runback!(model::Model, Δ)
-  model.grads[:input][:] = 0
-  mx.backward(model.exec, MXArray(Δ).data)
-  copy(model.grads[:input])
+function Flux.back!(exec::Exec, Δ)
+  exec.grads[:input][:] = 0
+  mx.backward(exec.exec, MXArray(Δ).data)
+  copy(exec.grads[:input])
 end
-Flux.back!(m::Model, Δ::Batch, x) = rebatch(runback!(m, rawbatch(Δ)))
-Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
-function Flux.update!(model::Model, η)
-  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
+function Flux.update!(exec::Exec, η)
+  for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
     mx.@nd_as_jl rw = (arg, grad) begin
       arg .-= grad .* η
       grad[:] = 0
     end
   end
-  storeparams!(model)
-  return model
+  storeparams!(exec)
+  return exec
 end
+# TODO: if `last` changes, update params appropriately
+type Model
+  model::Any
+  graph::Graph
+  execs::Dict{Tuple,Exec}
+  last::Exec
+  Model(model, graph, execs) = new(model, graph, execs)
+end
+function mxnet(model)
+  graph = tograph(model, mx.Variable(:input))
+  Model(model, graph, Dict())
+end
+import Base: @get!
+executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
+function (m::Model)(x::Batch)
+  x = rawbatch(x)
+  m.last = exec = @mxerr m.graph.stacks executor(m, size(x))
+  rebatch(exec(x))
+end
+(m::Model)(x) = unbatchone(m(batchone(x)))
+function Flux.back!(m::Model, Δ::Batch, x::Batch)
+  m.last = exec = m.execs[size(rawbatch(x))]
+  rebatch(back!(exec, rawbatch(Δ)))
+end
+Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x)))
+Flux.update!(m::Model, η) = (update!(m.last, η); m)
 # MX FeedForward interface
 type SoftmaxOutput
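The core of the new API above is that `mxnet(model)` no longer binds an MXNet executor up front: calling the model builds one lazily via `executor(m, size(x))`, memoised in `m.execs` with `Base.@get!`. A standalone sketch of that caching pattern, with no MXNet dependency and made-up names, just to illustrate the mechanism:

```julia
import Base: @get!   # non-exported memoising helper: @get!(dict, key, default)

const execs = Dict{Tuple,String}()

# Stand-in for `executor(graph, input)`: pretend binding one is expensive.
build_exec(sz::Tuple) = (println("binding executor for $sz"); "exec$sz")

# `default` is only evaluated (and stored) when the key is missing,
# so each distinct input size pays the setup cost exactly once.
getexec(sz::Tuple) = @get!(execs, sz, build_exec(sz))

getexec((1, 20))   # prints "binding executor for (1, 20)"
getexec((1, 20))   # cache hit, nothing printed
getexec((50, 20))  # a new batch size binds a second executor
```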

@@ -6,11 +6,11 @@ Flux.loadmx()
 xs = rand(20)
 d = Affine(20, 10)
-dm = mxnet(d, (1, 20))
+dm = mxnet(d)
 @test d(xs) ≈ dm(xs)
 m = Multi(20, 15)
-mm = mxnet(m, (1, 20))
+mm = mxnet(m)
 @test all(isapprox.(mm(xs), m(xs)))
 @testset "Backward Pass" begin
@@ -40,7 +40,8 @@ end
 @testset "Stack Traces" begin
   model = TLP(Affine(10, 20), Affine(21, 15))
   info("The following warning is normal")
-  e = try mxnet(model, (10, 1))
+  dm = mxnet(model)
+  e = try dm(rand(10))
   catch e e end
   @test isa(e, DataFlow.Interpreter.Exception)