nicer mxnet api

Mike J Innes 2017-03-08 21:41:13 +00:00
parent 9c9feb9ba0
commit 854a1e1886
5 changed files with 82 additions and 53 deletions

View File

@@ -18,7 +18,7 @@ directly won't have great performance. In order to run a computationally intensi
This is easy to do. Just call either `mxnet` or `tf` on a model to convert it to a model of that kind:
```julia
mxmodel = mxnet(model, (10, 1))
mxmodel = mxnet(model)
mxmodel(xs) #> [0.0650, 0.0655, ...]
# or
tfmodel = tf(model)

View File

@@ -14,18 +14,20 @@ end
model = TLP(Affine(10, 20), Affine(21, 15))
mxmodel = mxnet(model, (10, 1))
mxmodel = mxnet(model)
mxmodel(rand(10))
```
Unfortunately, this model has a (fairly obvious) typo: the second `Affine` layer expects 21 inputs, but the first layer only produces 20. This means that the code above won't run. Instead we get an error message:
```julia
InferShape Error in dot5: [20:37:39] src/operator/./matrix_op-inl.h:271:
Check failed: (lshape[1]) == (rshape[0]) dot shape error: (15,21) X (20,1)
in Flux.Affine at affine.jl:8
in TLP at test.jl:6
in mxnet(::TLP, ::Tuple{Int64,Int64}) at model.jl:40
in mxnet(::TLP, ::Vararg{Any,N} where N) at backend.jl:20
Error in operator dot2: [21:28:21] src/operator/tensor/./matrix_op-inl.h:460:
Check failed: lshape[1] == rshape[0] (20 vs. 21) dot shape error: (1,20) X (21,15)
Flux.Affine at affine.jl:8
TLP at basic.jl:6
(::Flux.MX.Model)(::Flux.Batch{Array{Float64,1},Array{Float64,2}}) at model.jl:105
(::Flux.MX.Model)(::Array{Float64,1}) at model.jl:107
```
Most frameworks would only give the error message itself here, which is not so helpful if you have thousands of nodes in your computational graph. However, Flux is able to give good error reports *even when no Julia code has been run*, e.g. when running on a backend like MXNet. This enables us to pinpoint the source of the error very quickly, even in a large model.
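For reference, the fix is simply to make the second layer's input size match the first layer's output size:
```julia
model = TLP(Affine(10, 20), Affine(20, 15))
mxmodel = mxnet(model)
mxmodel(rand(10)) # runs without error
```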

View File

@@ -15,7 +15,7 @@ function loadmx()
@eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
end
function mxnet(args...)
function mxnet(m)
loadmx()
eval(:(MX.mxnet($(args...))))
eval(:(MX.mxnet($m)))
end

View File

@@ -8,6 +8,12 @@ end
Base.size(p::AlterParam) = size(p.load(p.param.x))
function copyargs!(as, bs)
for id in intersect(keys(as), keys(bs))
copy!(as[id], bs[id])
end
end
type Graph
output
params::Dict{Symbol,Any}
@@ -22,75 +28,95 @@ function mxparams(g::Graph)
return params
end
function copyargs!(as, bs)
for id in intersect(keys(as), keys(bs))
copy!(as[id], bs[id])
end
end
ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
type Model <: Flux.Model
model::Any
type Exec <: Flux.Model
graph::Graph
exec::mx.Executor
args::Dict{Symbol,MXArray}
grads::Dict{Symbol,MXArray}
outs::Vector{MXArray}
exec::mx.Executor
end
loadparams!(model::Model) = copyargs!(model.args, model.graph.params)
storeparams!(model::Model) = copyargs!(model.graph.params, model.args)
loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params)
storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args)
mxgroup(x) = x
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
mxungroup(x, outs) = copy(shift!(outs))
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
function mxnet(model::Flux.Model, input)
graph = tograph(model, mx.Variable(:input))
function executor(graph::Graph, input)
args = merge(mxparams(graph), Dict(:input => MXArray(input)))
grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
args = ndparams(args),
args_grad = ndparams(grads),
grad_req = mx.GRAD_ADD)
model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
loadparams!(model)
return model
exec = mx.bind(mxgroup(graph.output),
args = ndparams(args),
args_grad = ndparams(grads),
grad_req = mx.GRAD_ADD)
exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
loadparams!(exec)
return exec
end
function runmodel(model::Model, input)
copy!(model.args[:input], input)
mx.forward(model.exec, is_train = true)
mxungroup(model.graph.output, copy(model.outs))
function (exec::Exec)(input)
copy!(exec.args[:input], input)
mx.forward(exec.exec, is_train = true)
mxungroup(exec.graph.output, copy(exec.outs))
end
(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
(m::Model)(x) = unbatchone(m(batchone(x)))
function runback!(model::Model, Δ)
model.grads[:input][:] = 0
mx.backward(model.exec, MXArray(Δ).data)
copy(model.grads[:input])
function Flux.back!(exec::Exec, Δ)
exec.grads[:input][:] = 0
mx.backward(exec.exec, MXArray(Δ).data)
copy(exec.grads[:input])
end
Flux.back!(m::Model, Δ::Batch, x) = rebatch(runback!(m, rawbatch(Δ)))
Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
function Flux.update!(model::Model, η)
for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
function Flux.update!(exec::Exec, η)
for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
mx.@nd_as_jl rw = (arg, grad) begin
arg .-= grad .* η
grad[:] = 0
end
end
storeparams!(model)
return model
storeparams!(exec)
return exec
end
# TODO: if `last` changes, update params appropriately
type Model
model::Any
graph::Graph
execs::Dict{Tuple,Exec}
last::Exec
Model(model, graph, execs) = new(model, graph, execs)
end
function mxnet(model)
graph = tograph(model, mx.Variable(:input))
Model(model, graph, Dict())
end
import Base: @get!
executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
function (m::Model)(x::Batch)
x = rawbatch(x)
m.last = exec = @mxerr m.graph.stacks executor(m, size(x))
rebatch(exec(x))
end
(m::Model)(x) = unbatchone(m(batchone(x)))
function Flux.back!(m::Model, Δ::Batch, x::Batch)
m.last = exec = m.execs[size(rawbatch(x))]
rebatch(back!(exec, rawbatch(Δ)))
end
Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x)))
Flux.update!(m::Model, η) = (update!(m.last, η); m)
# MX FeedForward interface
type SoftmaxOutput
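
To summarise the new design: `mxnet(model)` now returns a lazy `Model` wrapper, and an `Exec` is only bound, and then cached keyed by input size, the first time the model is called on data. A rough usage sketch, reusing the `Affine` layer from the tests below:
```julia
dm = mxnet(Affine(20, 10)) # no input shape required up front
dm(rand(20))               # first call binds an Exec for this input size
dm(rand(20))               # same-sized calls reuse the cached Exec
```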

View File

@@ -6,11 +6,11 @@ Flux.loadmx()
xs = rand(20)
d = Affine(20, 10)
dm = mxnet(d, (1, 20))
dm = mxnet(d)
@test d(xs) ≈ dm(xs)
m = Multi(20, 15)
mm = mxnet(m, (1, 20))
mm = mxnet(m)
@test all(isapprox.(mm(xs), m(xs)))
@testset "Backward Pass" begin
@@ -40,7 +40,8 @@ end
@testset "Stack Traces" begin
model = TLP(Affine(10, 20), Affine(21, 15))
info("The following warning is normal")
e = try mxnet(model, (10, 1))
dm = mxnet(model)
e = try dm(rand(10))
catch e e end
@test isa(e, DataFlow.Interpreter.Exception)