custom mxnet context

Mike J Innes 2017-06-09 00:56:52 +01:00
parent fe0bddd98d
commit 020ae616cc
2 changed files with 22 additions and 13 deletions

View File

@@ -13,10 +13,10 @@ struct Graph
   stacks::Dict{Any,Any}
 end
 
-function mxparams(ps)
+function mxparams(ps, ctx)
   params = Dict{Symbol,MXArray}()
   for (name, param) in ps
-    params[name] = MXArray(size(param))
+    params[name] = MXArray(size(param), ctx)
   end
   return params
 end
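For illustration, a minimal sketch of what the reworked mxparams now does with its extra argument, assuming the Flux MXNet backend is loaded; the parameter dict below is hypothetical and merely stands in for graph.params (mxparams only calls size on each value):

ps = Dict(:weight => rand(Float32, 10, 5), :bias => rand(Float32, 10))
mxparams(ps, mx.cpu())   # Dict{Symbol,MXArray} with zero-filled CPU storage
mxparams(ps, mx.gpu())   # same shapes, but each buffer is allocated on the GPU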
@@ -25,6 +25,7 @@ ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)
 
 struct Exec
   graph::Graph
+  ctx::mx.Context
   exec::mx.Executor
   args::Dict{Symbol,MXArray}
   grads::Dict{Symbol,MXArray}
@@ -41,16 +42,17 @@ mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
 
 dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
 
-function executor(graph::Graph, input...)
+function executor(graph::Graph, input...; ctx = mx.cpu())
   shapecheckt(graph.input, input)
-  args = merge(mxparams(graph.params), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
+  args = merge(mxparams(graph.params, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
   grads = filter((a, b) -> b isa Flux.Param, graph.params)
-  grads = merge(mxparams(grads), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
+  grads = merge(mxparams(grads, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
   exec = mx.bind(mxgroup(graph.output),
+                 context = ctx,
                  args = ndparams(args),
                  args_grad = ndparams(grads),
                  grad_req = mx.GRAD_ADD)
-  exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
+  exec = Exec(graph, ctx, exec, args, grads, MXArray.(exec.outputs))
   loadparams!(exec)
   return exec
 end
@@ -63,7 +65,7 @@ end
 
 function Flux.back!(exec::Exec, Δ)
   mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
-  mx.backward(exec.exec, map(x -> MXArray(x).data, collectt(Δ)))
+  mx.backward(exec.exec, map(x -> MXArray(x, exec.ctx).data, collectt(Δ)))
   mapt(k -> copy(exec.grads[k]), exec.graph.input)
 end
 
@@ -79,22 +81,28 @@ function Flux.update!(exec::Exec, η)
   return exec
 end
 
+toctx(ctx::mx.Context) = ctx
+toctx(c::Symbol) = c == :gpu ? mx.gpu() : mx.cpu()
+
 # TODO: if `last` changes, update params appropriately
 mutable struct Model
   model::Any
+  ctx::mx.Context
   execs::Dict{Tuple,Exec}
   graph::Graph
   last::Exec
 
-  Model(model) = new(model, Dict())
+  Model(model, ctx) = new(model, ctx, Dict())
 end
 
-mxnet(model) = Model(model)
+mxnet(model, ctx = :cpu) = Model(model, toctx(ctx))
 
 import Base: @get!
 
 # TODO: dims having its own type would be useful
-executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
+executor(m::Model, input...) =
+  @get!(m.execs, mapt(size, input),
+        executor(m.graph, input...; ctx = m.ctx))
 
 function (m::Model)(xs...)
   @mxerr m.graph.stacks begin
@@ -134,10 +142,10 @@ function rewrite_softmax(model, name)
   return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
 end
 
-function FeedForward(model; input = :data, label = :softmax, context = mx.cpu())
+function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
   model = rewrite_softmax(model, label)
   graph = tograph(model, input, feedforward=true)
   ff = mx.FeedForward(graph.output, context = context)
-  isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params)))
+  isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
   return ff
 end
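Taken together, the changes in this file let the target device be chosen when a model is wrapped for MXNet. A hedged usage sketch, assuming a Flux model m defined elsewhere and a CUDA-enabled MXNet build for the GPU variants:

using Flux, MXNet

cpu_model = mxnet(m)             # default ctx = :cpu, i.e. mx.cpu()
gpu_model = mxnet(m, :gpu)       # toctx(:gpu) == mx.gpu()
dev_model = mxnet(m, mx.gpu(1))  # an explicit mx.Context passes through toctx unchanged

# Executors are still created lazily per input shape, but are now bound with
# context = m.ctx, so parameters, gradients and outputs live on that device.
gpu_model(rand(Float32, 10))     # hypothetical input; the shape depends on m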

View File

@@ -12,7 +12,8 @@ end
 
 MXArray(data::mx.NDArray) = MXArray(data, Array{Float32}(size(data)))
 
-MXArray(dims::Dims) = MXArray(mx.zeros(reverse(dims)))
+# TODO: split cpu/gpu mxarrays
+MXArray(dims::Dims, ctx = mx.cpu()) = MXArray(mx.zeros(reverse(dims), ctx))
 
 Base.size(xs::MXArray) = reverse(size(xs.data))
 
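At the lowest level, the MXArray constructor now takes the context it allocates on, defaulting to the CPU as before. A small sketch of the resulting behaviour, relying only on the mx.zeros call shown above:

a = MXArray((10, 5))            # zeros on mx.cpu(), the unchanged default
b = MXArray((10, 5), mx.gpu())  # the same zero-filled buffer, allocated on the GPU
size(b)                         # (10, 5); note the dims are stored reversed internally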