nicer mxnet api
parent 9c9feb9ba0 · commit 854a1e1886
@@ -18,7 +18,7 @@ directly won't have great performance. In order to run a computationally intensi
 This is easy to do. Just call either `mxnet` or `tf` on a model to convert it to a model of that kind:
 
 ```julia
-mxmodel = mxnet(model, (10, 1))
+mxmodel = mxnet(model)
 mxmodel(xs) #> [0.0650, 0.0655, ...]
 # or
 tfmodel = tf(model)
@@ -14,18 +14,20 @@ end
 
 model = TLP(Affine(10, 20), Affine(21, 15))
 
-mxmodel = mxnet(model, (10, 1))
+mxmodel = mxnet(model)
+
+mxmodel(rand(10))
 ```
 
 Unfortunately, this model has a (fairly obvious) typo, which means that the code above won't run. Instead we get an error message:
 
 ```julia
-InferShape Error in dot5: [20:37:39] src/operator/./matrix_op-inl.h:271:
-Check failed: (lshape[1]) == (rshape[0]) dot shape error: (15,21) X (20,1)
-in Flux.Affine at affine.jl:8
-in TLP at test.jl:6
-in mxnet(::TLP, ::Tuple{Int64,Int64}) at model.jl:40
-in mxnet(::TLP, ::Vararg{Any,N} where N) at backend.jl:20
+Error in operator dot2: [21:28:21] src/operator/tensor/./matrix_op-inl.h:460:
+Check failed: lshape[1] == rshape[0] (20 vs. 21) dot shape error: (1,20) X (21,15)
+Flux.Affine at affine.jl:8
+TLP at basic.jl:6
+(::Flux.MX.Model)(::Flux.Batch{Array{Float64,1},Array{Float64,2}}) at model.jl:105
+(::Flux.MX.Model)(::Array{Float64,1}) at model.jl:107
 ```
 
 Most frameworks would only give the error message here – not so helpful if you have thousands of nodes in your computational graph. However, Flux is able to give good error reports *even when no Julia code has been run*, e.g. when running on a backend like MXNet. This enables us to pinpoint the source of the error very quickly even in a large model.
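The error above is an ordinary, catchable Julia exception carrying the interpreter stacks; the updated "Stack Traces" test at the bottom of this commit asserts exactly that. A minimal session along the same lines (reusing the mistyped `model` from above):

```julia
dm = mxnet(model)
e = try
  dm(rand(10))  # the shape mismatch only surfaces once MXNet binds the graph
catch err
  err
end
isa(e, DataFlow.Interpreter.Exception)  #> true
```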
@@ -15,7 +15,7 @@ function loadmx()
   @eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
 end
 
-function mxnet(args...)
+function mxnet(m)
   loadmx()
-  eval(:(MX.mxnet($(args...))))
+  eval(:(MX.mxnet($m)))
 end
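One note on the `eval` here: the `MX` module does not exist until `loadmx()` has `include`d the backend, so `MX.mxnet` cannot be referenced directly in the function body; the call is built as an expression and evaluated after loading. The same deferred-loading pattern in isolation (file and module names hypothetical):

```julia
# Sketch: include heavy backend code on first use, then dispatch into it.
const _loaded = Ref(false)

function loadbackend()
  _loaded[] && return
  @eval include("backend.jl")  # assumed to define a module `Backend`
  _loaded[] = true
end

convert_model(m) = (loadbackend(); eval(:(Backend.convert($m))))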
@@ -8,6 +8,12 @@ end
 
 Base.size(p::AlterParam) = size(p.load(p.param.x))
 
+function copyargs!(as, bs)
+  for id in intersect(keys(as), keys(bs))
+    copy!(as[id], bs[id])
+  end
+end
+
 type Graph
   output
   params::Dict{Symbol,Any}
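The newly added `copyargs!` (moved here from its old spot in the next hunk) copies values only for the keys the two collections share; it is what shuttles parameter values between the graph's Julia-side arrays and the executor's buffers via `loadparams!`/`storeparams!`. A quick illustration with hypothetical values:

```julia
as = Dict(:W => zeros(2), :b => zeros(3))
bs = Dict(:W => [1.0, 2.0], :v => [9.0])
copyargs!(as, bs)  # only :W is common to both dicts
as[:W]             #> [1.0, 2.0]; :b and :v are untouched
```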
@@ -22,75 +28,95 @@ function mxparams(g::Graph)
   return params
 end
 
-function copyargs!(as, bs)
-  for id in intersect(keys(as), keys(bs))
-    copy!(as[id], bs[id])
-  end
-end
-
 ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
 
-type Model <: Flux.Model
-  model::Any
+type Exec <: Flux.Model
   graph::Graph
+  exec::mx.Executor
   args::Dict{Symbol,MXArray}
   grads::Dict{Symbol,MXArray}
   outs::Vector{MXArray}
-  exec::mx.Executor
 end
 
-loadparams!(model::Model) = copyargs!(model.args, model.graph.params)
-storeparams!(model::Model) = copyargs!(model.graph.params, model.args)
+loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params)
+storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args)
 
 mxgroup(x) = x
 mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
 mxungroup(x, outs) = copy(shift!(outs))
 mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
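`mxgroup` flattens a (possibly nested) tuple of outputs into a single `mx.Group` symbol, since the executor returns its outputs as a flat list; `mxungroup` walks the same tuple structure to rebuild it from that list. The round trip, shown on plain values instead of `mx` symbols (stand-in function, not part of this commit):

```julia
# Stand-in for mxungroup: consume a flat output list in structure order.
ungroup(x, outs) = shift!(outs)                          # leaf: take the next output
ungroup(x::Tuple, outs) = map(x -> ungroup(x, outs), x)  # tuple: recurse

ungroup((:a, (:b, :c)), Any[10, 20, 30])  #> (10, (20, 30))
```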
 
-function mxnet(model::Flux.Model, input)
-  graph = tograph(model, mx.Variable(:input))
+function executor(graph::Graph, input)
   args = merge(mxparams(graph), Dict(:input => MXArray(input)))
   grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
-  exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
-                                     args = ndparams(args),
-                                     args_grad = ndparams(grads),
-                                     grad_req = mx.GRAD_ADD)
-  model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
-  loadparams!(model)
-  return model
+  exec = mx.bind(mxgroup(graph.output),
+                 args = ndparams(args),
+                 args_grad = ndparams(grads),
+                 grad_req = mx.GRAD_ADD)
+  exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
+  loadparams!(exec)
+  return exec
 end
 
-function runmodel(model::Model, input)
-  copy!(model.args[:input], input)
-  mx.forward(model.exec, is_train = true)
-  mxungroup(model.graph.output, copy(model.outs))
+function (exec::Exec)(input)
+  copy!(exec.args[:input], input)
+  mx.forward(exec.exec, is_train = true)
+  mxungroup(exec.graph.output, copy(exec.outs))
 end
 
-(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
-
-(m::Model)(x) = unbatchone(m(batchone(x)))
-
-function runback!(model::Model, Δ)
-  model.grads[:input][:] = 0
-  mx.backward(model.exec, MXArray(Δ).data)
-  copy(model.grads[:input])
+function Flux.back!(exec::Exec, Δ)
+  exec.grads[:input][:] = 0
+  mx.backward(exec.exec, MXArray(Δ).data)
+  copy(exec.grads[:input])
 end
 
-Flux.back!(m::Model, Δ::Batch, x) = rebatch(runback!(m, rawbatch(Δ)))
-
-Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
-
-function Flux.update!(model::Model, η)
-  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
+function Flux.update!(exec::Exec, η)
+  for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
     mx.@nd_as_jl rw = (arg, grad) begin
       arg .-= grad .* η
      grad[:] = 0
    end
  end
-  storeparams!(model)
-  return model
+  storeparams!(exec)
+  return exec
 end
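Since the executor is bound with `grad_req = mx.GRAD_ADD`, gradients accumulate across backward passes, which is why the loop zeroes each buffer after applying it. The per-parameter step itself is plain SGD; on ordinary arrays it amounts to:

```julia
η    = 0.1
arg  = [1.0, 2.0]
grad = [0.5, 0.5]
arg .-= grad .* η  # SGD step: arg is now [0.95, 1.95]
grad[:] = 0        # clear the accumulated gradient before the next pass
```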
 
+# TODO: if `last` changes, update params appropriately
+
+type Model
+  model::Any
+  graph::Graph
+  execs::Dict{Tuple,Exec}
+  last::Exec
+  Model(model, graph, execs) = new(model, graph, execs)
+end
+
+function mxnet(model)
+  graph = tograph(model, mx.Variable(:input))
+  Model(model, graph, Dict())
+end
+
+import Base: @get!
+
+executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
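`Base.@get!` (unexported) is get-or-compute-and-store: the default expression runs only on a cache miss, so each distinct input shape gets exactly one bound executor, which the `(m::Model)(x::Batch)` method below then reuses. Roughly equivalent to (a sketch, not the macro's exact expansion):

```julia
function cached_executor(m::Model, input)
  haskey(m.execs, input) ? m.execs[input] :
                           (m.execs[input] = executor(m.graph, input))
end
```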
+
+function (m::Model)(x::Batch)
+  x′ = rawbatch(x)
+  m.last = exec = @mxerr m.graph.stacks executor(m, size(x′))
+  rebatch(exec(x′))
+end
+
+(m::Model)(x) = unbatchone(m(batchone(x)))
+
+function Flux.back!(m::Model, Δ::Batch, x::Batch)
+  m.last = exec = m.execs[size(rawbatch(x))]
+  rebatch(back!(exec, rawbatch(Δ)))
+end
+
+Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x)))
+
+Flux.update!(m::Model, η) = (update!(m.last, η); m)
+
 # MX FeedForward interface
 
 type SoftmaxOutput
@@ -6,11 +6,11 @@ Flux.loadmx()
 xs = rand(20)
 d = Affine(20, 10)
 
-dm = mxnet(d, (1, 20))
+dm = mxnet(d)
 @test d(xs) ≈ dm(xs)
 
 m = Multi(20, 15)
-mm = mxnet(m, (1, 20))
+mm = mxnet(m)
 @test all(isapprox.(mm(xs), m(xs)))
 
 @testset "Backward Pass" begin
@@ -40,7 +40,8 @@ end
 @testset "Stack Traces" begin
   model = TLP(Affine(10, 20), Affine(21, 15))
   info("The following warning is normal")
-  e = try mxnet(model, (10, 1))
+  dm = mxnet(model)
+  e = try dm(rand(10))
   catch e e end
 
   @test isa(e, DataFlow.Interpreter.Exception)