From fe0bddd98d8c1e15465ed425074a802ea556659c Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Fri, 9 Jun 2017 00:55:54 +0100
Subject: [PATCH 1/3] pass args correctly

---
 src/backend/backend.jl | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/backend/backend.jl b/src/backend/backend.jl
index 23ea8629..ebf80ee0 100644
--- a/src/backend/backend.jl
+++ b/src/backend/backend.jl
@@ -14,7 +14,7 @@ end
 
 function tf(args...)
   loadtf()
-  eval(:(TF.tf($(args...))))
+  eval(:(TF.tf($(QuoteNode.(args)...))))
 end
 
 function loadmx()
@@ -22,7 +22,7 @@ function loadmx()
   @eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
 end
 
-function mxnet(m)
+function mxnet(args...)
   loadmx()
-  eval(:(MX.mxnet($m)))
+  eval(:(MX.mxnet($(QuoteNode.(args)...))))
 end

From 020ae616ccf5c8c5dd9545e07b3d04e69ab3f0f8 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Fri, 9 Jun 2017 00:56:52 +0100
Subject: [PATCH 2/3] custom mxnet context

---
 src/backend/mxnet/model.jl   | 34 +++++++++++++++++++++-------------
 src/backend/mxnet/mxarray.jl |  3 ++-
 2 files changed, 23 insertions(+), 14 deletions(-)

diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 15bd69d5..19fe53a1 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -13,10 +13,10 @@ struct Graph
   stacks::Dict{Any,Any}
 end
 
-function mxparams(ps)
+function mxparams(ps, ctx)
   params = Dict{Symbol,MXArray}()
   for (name, param) in ps
-    params[name] = MXArray(size(param))
+    params[name] = MXArray(size(param), ctx)
   end
   return params
 end
@@ -25,6 +25,7 @@ ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)
 
 struct Exec
   graph::Graph
+  ctx::mx.Context
   exec::mx.Executor
   args::Dict{Symbol,MXArray}
   grads::Dict{Symbol,MXArray}
@@ -41,16 +42,17 @@ mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
 
 dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
 
-function executor(graph::Graph, input...)
+function executor(graph::Graph, input...; ctx = mx.cpu())
   shapecheckt(graph.input, input)
-  args = merge(mxparams(graph.params), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
+  args = merge(mxparams(graph.params, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
   grads = filter((a, b) -> b isa Flux.Param, graph.params)
-  grads = merge(mxparams(grads), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
+  grads = merge(mxparams(grads, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
   exec = mx.bind(mxgroup(graph.output),
+                 context = ctx,
                  args = ndparams(args),
                  args_grad = ndparams(grads),
                  grad_req = mx.GRAD_ADD)
-  exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
+  exec = Exec(graph, ctx, exec, args, grads, MXArray.(exec.outputs))
   loadparams!(exec)
   return exec
 end
@@ -63,7 +65,7 @@ end
 
 function Flux.back!(exec::Exec, Δ)
   mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
-  mx.backward(exec.exec, map(x -> MXArray(x).data, collectt(Δ)))
+  mx.backward(exec.exec, map(x -> MXArray(x, exec.ctx).data, collectt(Δ)))
   mapt(k -> copy(exec.grads[k]), exec.graph.input)
 end
 
@@ -79,22 +81,28 @@ function Flux.update!(exec::Exec, η)
   return exec
 end
 
+toctx(ctx::mx.Context) = ctx
+toctx(c::Symbol) = c == :gpu ? mx.gpu() : mx.cpu()
+
 # TODO: if `last` changes, update params appropriately
 mutable struct Model
   model::Any
+  ctx::mx.Context
   execs::Dict{Tuple,Exec}
   graph::Graph
   last::Exec
-  Model(model) = new(model, Dict())
+  Model(model, ctx) = new(model, ctx, Dict())
 end
 
-mxnet(model) = Model(model)
+mxnet(model, ctx = :cpu) = Model(model, toctx(ctx))
 
 import Base: @get!
 
 # TODO: dims having its own type would be useful
-executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
+executor(m::Model, input...) =
+  @get!(m.execs, mapt(size, input),
+        executor(m.graph, input...; ctx = m.ctx))
 
 function (m::Model)(xs...)
   @mxerr m.graph.stacks begin
@@ -134,10 +142,10 @@ function rewrite_softmax(model, name)
   return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
 end
 
-function FeedForward(model; input = :data, label = :softmax, context = mx.cpu())
+function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
   model = rewrite_softmax(model, label)
   graph = tograph(model, input, feedforward=true)
-  ff = mx.FeedForward(graph.output, context = context)
-  isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params)))
+  ff = mx.FeedForward(graph.output, context = ctx)
+  isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
   return ff
 end
diff --git a/src/backend/mxnet/mxarray.jl b/src/backend/mxnet/mxarray.jl
index 67dd6503..f89ed734 100644
--- a/src/backend/mxnet/mxarray.jl
+++ b/src/backend/mxnet/mxarray.jl
@@ -12,7 +12,8 @@ end
 
 MXArray(data::mx.NDArray) = MXArray(data, Array{Float32}(size(data)))
 
-MXArray(dims::Dims) = MXArray(mx.zeros(reverse(dims)))
+# TODO: split cpu/gpu mxarrays
+MXArray(dims::Dims, ctx = mx.cpu()) = MXArray(mx.zeros(reverse(dims), ctx))
 
 Base.size(xs::MXArray) = reverse(size(xs.data))
 

From 1cc8100456c01ae02c714980bc7a8d2c9b08bf72 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Fri, 9 Jun 2017 01:35:07 +0100
Subject: [PATCH 3/3] ctx methods for seq models

---
 src/backend/mxnet/model.jl | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 19fe53a1..6a09d44f 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -124,8 +124,8 @@ Flux.update!(m::Model, η) = (update!(m.last, η); m)
 
 using Flux: Stateful, SeqModel
 
-mxnet(m::Stateful) = Stateful(mxnet(m.model), m.states, m.istate, m.ostate)
-mxnet(m::SeqModel) = SeqModel(mxnet(m.model), m.steps)
+mxnet(m::Stateful, a...) = Stateful(mxnet(m.model, a...), m.states, m.istate, m.ostate)
+mxnet(m::SeqModel, a...) = SeqModel(mxnet(m.model, a...), m.steps)
 
 # MX FeedForward interface
 
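
A minimal usage sketch, assuming the series is applied and MXNet.jl is
installed: `model` and `x` below are hypothetical stand-ins for a Flux model
and a matching input batch; `mxnet` and `toctx` come from the patched code,
while `mx.cpu`/`mx.gpu` are MXNet.jl's context constructors.

    using Flux, MXNet

    cpu_model = mxnet(model)             # ctx defaults to :cpu, so toctx(:cpu) -> mx.cpu()
    gpu_model = mxnet(model, :gpu)       # toctx(:gpu) -> mx.gpu()
    dev_model = mxnet(model, mx.gpu(1))  # an mx.Context passes through toctx unchanged

    # Parameters, executors, and intermediate MXArrays are allocated on the
    # model's context, so this forward pass runs on the GPU:
    y = gpu_model(x)

Note that the QuoteNode splice from PATCH 1/3 is what lets the `:gpu` symbol
survive the eval round-trip in the backend dispatcher, rather than being
spliced into the generated call as a bare `gpu` variable reference.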