nicer mxnet api
parent 9c9feb9ba0 · commit 854a1e1886
@@ -18,7 +18,7 @@ directly won't have great performance. In order to run a computationally intensive
 This is easy to do. Just call either `mxnet` or `tf` on a model to convert it to a model of that kind:

 ```julia
-mxmodel = mxnet(model, (10, 1))
+mxmodel = mxnet(model)
 mxmodel(xs) #> [0.0650, 0.0655, ...]
 # or
 tfmodel = tf(model)
 ```
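The API change documented in this hunk: `mxnet` (and `tf`) no longer take an input shape up front; an executor is bound lazily for whatever input size the converted model first sees. A minimal before/after sketch, assuming the `Affine` layer used elsewhere in these docs:

```julia
model = Affine(10, 20)            # any Flux model works here

# Old API: the input shape had to be supplied at conversion time.
# mxmodel = mxnet(model, (10, 1))

# New API: just convert, then call; the backend compiles an executor
# for the (10,)-sized input on first use.
mxmodel = mxnet(model)
mxmodel(rand(10))
```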
@@ -14,18 +14,20 @@ end

 model = TLP(Affine(10, 20), Affine(21, 15))

-mxmodel = mxnet(model, (10, 1))
+mxmodel = mxnet(model)

+mxmodel(rand(10))
+
 ```

 Unfortunately, this model has a (fairly obvious) typo, which means that the code above won't run. Instead we get an error message:

 ```julia
-InferShape Error in dot5: [20:37:39] src/operator/./matrix_op-inl.h:271:
-Check failed: (lshape[1]) == (rshape[0]) dot shape error: (15,21) X (20,1)
-in Flux.Affine at affine.jl:8
-in TLP at test.jl:6
-in mxnet(::TLP, ::Tuple{Int64,Int64}) at model.jl:40
-in mxnet(::TLP, ::Vararg{Any,N} where N) at backend.jl:20
+Error in operator dot2: [21:28:21] src/operator/tensor/./matrix_op-inl.h:460:
+Check failed: lshape[1] == rshape[0] (20 vs. 21) dot shape error: (1,20) X (21,15)
+Flux.Affine at affine.jl:8
+TLP at basic.jl:6
+(::Flux.MX.Model)(::Flux.Batch{Array{Float64,1},Array{Float64,2}}) at model.jl:105
+(::Flux.MX.Model)(::Array{Float64,1}) at model.jl:107
 ```

 Most frameworks would only give the error message here – not so helpful if you have thousands of nodes in your computational graph. However, Flux is able to give good error reports *even when no Julia code has been run*, e.g. when running on a backend like MXNet. This enables us to pinpoint the source of the error very quickly even in a large model.
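The error-reporting claim above is exercised programmatically by the test changes at the end of this commit; a condensed sketch, with all names as they appear in the diff:

```julia
model = TLP(Affine(10, 20), Affine(21, 15))  # 20 vs. 21 is the deliberate typo
dm = mxnet(model)

# With the lazy API the shape check only fires once an executor is
# bound for a concrete input size, i.e. on the first call.
e = try
  dm(rand(10))
catch err
  err
end

# Flux attaches the Julia-side stacks recorded while building the graph,
# which is how the trace above can point at affine.jl:8 and basic.jl:6.
isa(e, DataFlow.Interpreter.Exception)  #> true
```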
@@ -15,7 +15,7 @@ function loadmx()
   @eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
 end

-function mxnet(args...)
+function mxnet(m)
   loadmx()
-  eval(:(MX.mxnet($(args...))))
+  eval(:(MX.mxnet($m)))
 end
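A note on the pattern in this hunk: the `MX` module is only `include`d on first use, so `MX.mxnet` does not exist at parse time and the call has to be spliced into an `eval`. A self-contained sketch of the same lazy-load trampoline, with hypothetical names (`loadbackend`, `Backend.convert_model`) standing in for Flux's:

```julia
const loaded = Ref(false)

function loadbackend()
  loaded[] && return
  # Flux does `@eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))`;
  # a stub module keeps this sketch runnable on its own.
  @eval module Backend
    convert_model(m) = (:converted, m)
  end
  loaded[] = true
end

function tobackend(m)
  loadbackend()
  # `Backend` is only defined at runtime, so the call itself goes
  # through eval too; `$m` splices the model in as a value.
  eval(:(Backend.convert_model($m)))
end

tobackend("mymodel")  #> (:converted, "mymodel")
```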
@@ -8,6 +8,12 @@ end

 Base.size(p::AlterParam) = size(p.load(p.param.x))

+function copyargs!(as, bs)
+  for id in intersect(keys(as), keys(bs))
+    copy!(as[id], bs[id])
+  end
+end
+
 type Graph
   output
   params::Dict{Symbol,Any}
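The `copyargs!` helper added here (moved up from later in the file, as the next hunk shows) synchronizes whichever parameter names two dictionaries share; a toy illustration with made-up data:

```julia
as = Dict(:W => zeros(2), :b => zeros(1))
bs = Dict(:W => [1.0, 2.0], :c => [9.9])

# copy! mutates in place; only keys present on both sides are touched.
for id in intersect(keys(as), keys(bs))
  copy!(as[id], bs[id])
end

as[:W]  #> [1.0, 2.0]   (:b is untouched, :c is ignored)
```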
@@ -22,75 +28,95 @@ function mxparams(g::Graph)
   return params
 end

-function copyargs!(as, bs)
-  for id in intersect(keys(as), keys(bs))
-    copy!(as[id], bs[id])
-  end
-end
-
 ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)

-type Model <: Flux.Model
-  model::Any
+type Exec <: Flux.Model
   graph::Graph
+  exec::mx.Executor
   args::Dict{Symbol,MXArray}
   grads::Dict{Symbol,MXArray}
   outs::Vector{MXArray}
-  exec::mx.Executor
 end

-loadparams!(model::Model) = copyargs!(model.args, model.graph.params)
-storeparams!(model::Model) = copyargs!(model.graph.params, model.args)
+loadparams!(exec::Exec) = copyargs!(exec.args, exec.graph.params)
+storeparams!(exec::Exec) = copyargs!(exec.graph.params, exec.args)

 mxgroup(x) = x
 mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
 mxungroup(x, outs) = copy(shift!(outs))
 mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)

-function mxnet(model::Flux.Model, input)
-  graph = tograph(model, mx.Variable(:input))
+function executor(graph::Graph, input)
   args = merge(mxparams(graph), Dict(:input => MXArray(input)))
   grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
-  exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
+  exec = mx.bind(mxgroup(graph.output),
                  args = ndparams(args),
                  args_grad = ndparams(grads),
                  grad_req = mx.GRAD_ADD)
-  model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
-  loadparams!(model)
-  return model
+  exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
+  loadparams!(exec)
+  return exec
 end

-function runmodel(model::Model, input)
-  copy!(model.args[:input], input)
-  mx.forward(model.exec, is_train = true)
-  mxungroup(model.graph.output, copy(model.outs))
+function (exec::Exec)(input)
+  copy!(exec.args[:input], input)
+  mx.forward(exec.exec, is_train = true)
+  mxungroup(exec.graph.output, copy(exec.outs))
 end

-(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
-
-(m::Model)(x) = unbatchone(m(batchone(x)))
-
-function runback!(model::Model, Δ)
-  model.grads[:input][:] = 0
-  mx.backward(model.exec, MXArray(Δ).data)
-  copy(model.grads[:input])
+function Flux.back!(exec::Exec, Δ)
+  exec.grads[:input][:] = 0
+  mx.backward(exec.exec, MXArray(Δ).data)
+  copy(exec.grads[:input])
 end

-Flux.back!(m::Model, Δ::Batch, x) = rebatch(runback!(m, rawbatch(Δ)))
-
-Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
-
-function Flux.update!(model::Model, η)
-  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
+function Flux.update!(exec::Exec, η)
+  for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
     mx.@nd_as_jl rw = (arg, grad) begin
       arg .-= grad .* η
       grad[:] = 0
     end
   end
-  storeparams!(model)
-  return model
+  storeparams!(exec)
+  return exec
 end

+# TODO: if `last` changes, update params appropriately
+
+type Model
+  model::Any
+  graph::Graph
+  execs::Dict{Tuple,Exec}
+  last::Exec
+  Model(model, graph, execs) = new(model, graph, execs)
+end
+
+function mxnet(model)
+  graph = tograph(model, mx.Variable(:input))
+  Model(model, graph, Dict())
+end
+
+import Base: @get!
+
+executor(m::Model, input) = @get!(m.execs, input, executor(m.graph, input))
+
+function (m::Model)(x::Batch)
+  x′ = rawbatch(x)
+  m.last = exec = @mxerr m.graph.stacks executor(m, size(x′))
+  rebatch(exec(x′))
+end
+
+(m::Model)(x) = unbatchone(m(batchone(x)))
+
+function Flux.back!(m::Model, Δ::Batch, x::Batch)
+  m.last = exec = m.execs[size(rawbatch(x))]
+  rebatch(back!(exec, rawbatch(Δ)))
+end
+
+Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), batchone(x)))
+
+Flux.update!(m::Model, η) = (update!(m.last, η); m)
+
 # MX FeedForward interface

 type SoftmaxOutput
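The heart of the refactor is visible in this hunk: `Exec` owns a single MXNet executor bound for one concrete input size, while the new `Model` wrapper lazily builds and caches one `Exec` per size via `@get!`, since a bound executor cannot change shape. A stripped-down sketch of that memoize-by-shape pattern, in the Julia 0.6-era dialect the file itself uses, with a hypothetical `compile` standing in for `executor(graph, input)`:

```julia
# Hypothetical stand-in for binding an MXNet executor to one input shape.
compile(sz) = (println("compiling for $sz"); x -> 2x)

type CachedModel
  execs::Dict{Tuple,Function}
end

function (m::CachedModel)(x)
  # Same effect as `@get!(m.execs, size(x), compile(size(x)))`:
  # reuse the cached executor if this input size was seen before.
  exec = get!(() -> compile(size(x)), m.execs, size(x))
  exec(x)
end

m = CachedModel(Dict())
m(rand(10))   # prints "compiling for (10,)"
m(rand(10))   # cache hit, no recompilation
m(rand(5))    # new size, compiles a second executor
```

The `last` field (and its TODO) records which `Exec` handled the most recent call, so that `update!` knows which executor's gradients to apply.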
@@ -6,11 +6,11 @@ Flux.loadmx()
 xs = rand(20)
 d = Affine(20, 10)

-dm = mxnet(d, (1, 20))
+dm = mxnet(d)
 @test d(xs) ≈ dm(xs)

 m = Multi(20, 15)
-mm = mxnet(m, (1, 20))
+mm = mxnet(m)
 @test all(isapprox.(mm(xs), m(xs)))

 @testset "Backward Pass" begin
@@ -40,7 +40,8 @@ end
 @testset "Stack Traces" begin
 model = TLP(Affine(10, 20), Affine(21, 15))
 info("The following warning is normal")
-e = try mxnet(model, (10, 1))
+dm = mxnet(model)
+e = try dm(rand(10))
 catch e e end

 @test isa(e, DataFlow.Interpreter.Exception)
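One last detail from `Flux.update!` above: the executor is bound with `grad_req = mx.GRAD_ADD`, so gradients accumulate across backward passes, which is why the update loop zeroes each gradient after applying it. A plain-Julia rendering of one step, with made-up numbers:

```julia
η = 0.1                 # learning rate
arg  = [1.0, 2.0]       # one parameter array
grad = [0.5, -0.5]      # its accumulated gradient

# The body of the mx.@nd_as_jl block, in ordinary array code:
arg .-= grad .* η       # arg is now [0.95, 2.05]
grad[:] = 0             # reset, since GRAD_ADD keeps summing into it
```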