From 020ae616ccf5c8c5dd9545e07b3d04e69ab3f0f8 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Fri, 9 Jun 2017 00:56:52 +0100
Subject: [PATCH] custom mxnet context

Thread an explicit `mx.Context` through the MXNet backend instead of
relying on the default device:

* `mxparams` and `MXArray(dims)` take a context; `MXArray(dims)` still
  defaults to `mx.cpu()`.
* `Exec` and `Model` store the context they were built with, and
  `executor` passes it to `mx.bind` and to every array it allocates.
* `mxnet(model, ctx = :cpu)` accepts either an `mx.Context` or a Symbol;
  `:gpu` maps to `mx.gpu()`, any other Symbol to `mx.cpu()`.
* `FeedForward` takes the context as the `ctx` keyword.

---
Review notes (text between the `---` marker and the diffstat is not part
of the commit message):

* FeedForward: the keyword is renamed from `context` to `ctx`, so the
  previously untouched `mx.FeedForward(graph.output, context = context)`
  line would reference an undefined variable. It now passes
  `context = ctx`. The diffstat below accounts for that extra changed
  line; the post-image blob id on the `index` line was not regenerated.
* NOTE(review): callers that used `FeedForward(...; context = ...)` are
  affected by the keyword rename - confirm this is intended.
* NOTE(review): `Flux.back!` now calls `MXArray(x, exec.ctx)` on the
  elements of `Δ`, while the only MXArray method changed in this patch is
  the `Dims` one - confirm a matching (array, ctx) method exists.
* Line breaks and indentation of this patch were reconstructed (2-space
  indentation assumed); verify the context lines against the tree before
  applying.

 src/backend/mxnet/model.jl   | 34 +++++++++++++++++++++-------------
 src/backend/mxnet/mxarray.jl |  3 ++-
 2 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 15bd69d5..19fe53a1 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -13,10 +13,10 @@ struct Graph
   stacks::Dict{Any,Any}
 end
 
-function mxparams(ps)
+function mxparams(ps, ctx)
   params = Dict{Symbol,MXArray}()
   for (name, param) in ps
-    params[name] = MXArray(size(param))
+    params[name] = MXArray(size(param), ctx)
   end
   return params
 end
@@ -25,6 +25,7 @@ ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)
 
 struct Exec
   graph::Graph
+  ctx::mx.Context
   exec::mx.Executor
   args::Dict{Symbol,MXArray}
   grads::Dict{Symbol,MXArray}
@@ -41,16 +42,17 @@ mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
 
 dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
 
-function executor(graph::Graph, input...)
+function executor(graph::Graph, input...; ctx = mx.cpu())
   shapecheckt(graph.input, input)
-  args = merge(mxparams(graph.params), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
+  args = merge(mxparams(graph.params, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
   grads = filter((a, b) -> b isa Flux.Param, graph.params)
-  grads = merge(mxparams(grads), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
+  grads = merge(mxparams(grads, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
   exec = mx.bind(mxgroup(graph.output),
+                 context = ctx,
                  args = ndparams(args),
                  args_grad = ndparams(grads),
                  grad_req = mx.GRAD_ADD)
-  exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
+  exec = Exec(graph, ctx, exec, args, grads, MXArray.(exec.outputs))
   loadparams!(exec)
   return exec
 end
@@ -63,7 +65,7 @@ end
 
 function Flux.back!(exec::Exec, Δ)
   mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
-  mx.backward(exec.exec, map(x -> MXArray(x).data, collectt(Δ)))
+  mx.backward(exec.exec, map(x -> MXArray(x, exec.ctx).data, collectt(Δ)))
   mapt(k -> copy(exec.grads[k]), exec.graph.input)
 end
 
@@ -79,22 +81,28 @@ function Flux.update!(exec::Exec, η)
   return exec
 end
 
+toctx(ctx::mx.Context) = ctx
+toctx(c::Symbol) = c == :gpu ? mx.gpu() : mx.cpu()
+
 # TODO: if `last` changes, update params appropriately
 
 mutable struct Model
   model::Any
+  ctx::mx.Context
   execs::Dict{Tuple,Exec}
   graph::Graph
   last::Exec
-  Model(model) = new(model, Dict())
+  Model(model, ctx) = new(model, ctx, Dict())
 end
 
-mxnet(model) = Model(model)
+mxnet(model, ctx = :cpu) = Model(model, toctx(ctx))
 
 import Base: @get!
 
 # TODO: dims having its own type would be useful
-executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
+executor(m::Model, input...) =
+  @get!(m.execs, mapt(size, input),
+        executor(m.graph, input...; ctx = m.ctx))
 
 function (m::Model)(xs...)
   @mxerr m.graph.stacks begin
@@ -134,10 +142,10 @@ function rewrite_softmax(model, name)
   return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
 end
 
-function FeedForward(model; input = :data, label = :softmax, context = mx.cpu())
+function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
   model = rewrite_softmax(model, label)
   graph = tograph(model, input, feedforward=true)
-  ff = mx.FeedForward(graph.output, context = context)
-  isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params)))
+  ff = mx.FeedForward(graph.output, context = ctx)
+  isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
   return ff
 end
diff --git a/src/backend/mxnet/mxarray.jl b/src/backend/mxnet/mxarray.jl
index 67dd6503..f89ed734 100644
--- a/src/backend/mxnet/mxarray.jl
+++ b/src/backend/mxnet/mxarray.jl
@@ -12,7 +12,8 @@ end
 
 MXArray(data::mx.NDArray) = MXArray(data, Array{Float32}(size(data)))
 
-MXArray(dims::Dims) = MXArray(mx.zeros(reverse(dims)))
+# TODO: split cpu/gpu mxarrays
+MXArray(dims::Dims, ctx = mx.cpu()) = MXArray(mx.zeros(reverse(dims), ctx))
 
 Base.size(xs::MXArray) = reverse(size(xs.data))
 