From 94cb98c13fee6ea42a3a15ec8bddfd4ba2b2e707 Mon Sep 17 00:00:00 2001
From: Mike J Innes
Date: Sat, 28 Jan 2017 22:32:49 +0530
Subject: [PATCH] basic mxnet backend

---
 src/backend/backend.jl     |  12 ++++-
 src/backend/mxnet/graph.jl |  73 ++++++++++++++++++++++++++++
 src/backend/mxnet/model.jl | 101 +++++++++++++++++++++++++++++++++++++
 src/backend/mxnet/mxnet.jl |  10 ++++
 4 files changed, 195 insertions(+), 1 deletion(-)
 create mode 100644 src/backend/mxnet/graph.jl
 create mode 100644 src/backend/mxnet/model.jl
 create mode 100644 src/backend/mxnet/mxnet.jl

diff --git a/src/backend/backend.jl b/src/backend/backend.jl
index 372148f9..7a1185b5 100644
--- a/src/backend/backend.jl
+++ b/src/backend/backend.jl
@@ -1,4 +1,4 @@
-export tf
+export tf, mxnet
 
 function loadtf()
   isdefined(Flux, :TF) && return
@@ -9,3 +9,13 @@ function tf(args...)
   loadtf()
   TF.tf(args...)
 end
+
+function loadmx()
+  isdefined(Flux, :MX) && return
+  @eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
+end
+
+function mxnet(args...)
+  loadmx()
+  MX.mxnet(args...)
+end
diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
new file mode 100644
index 00000000..183806bc
--- /dev/null
+++ b/src/backend/mxnet/graph.jl
@@ -0,0 +1,73 @@
+using Base: @get!
+using DataFlow: Constant, constant, Context, interpret, Split,
+  interpv, ituple, ilambda, iconst, iline, stack, mux
+using Flux: imap
+
+# TODO: implement Julia's type promotion rules
+
+node(x::Tuple) = map(node, x)
+node(x::mx.SymbolicNode) = x
+# node(x::Number) = TensorFlow.constant(Float32(x))
+
+graph(::typeof(tuple), args...) = (args...,)
+graph(s::Split, t::Tuple) = t[s.n]
+graph(::typeof(*), args...) = mx.dot(reverse(args)...)
+graph(::typeof(+), args...) = mx.broadcast_plus(args...)
+graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
+graph(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
+graph(::typeof(tanh), x) = mx.Activation(data = x, act_type=:tanh)
+graph(::typeof(flatten), x) = mx.Flatten(data = x)
+
+graph(::typeof(softmax), xs) =
+  mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
+
+graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
+graph(::typeof(vcat), a...) = node(cat, 1, a...)
+
+graph(::Input, x) = x
+
+# graph(vars, c::Conv, x) =
+#   mx.Convolution(data = x,
+#                  kernel = c.size,
+#                  num_filter = c.features,
+#                  stride = c.stride)
+#
+# graph(vars, p::MaxPool, x) =
+#   mx.Pooling(data = x,
+#              pool_type = :max,
+#              kernel = p.size,
+#              stride = p.stride)
+#
+# graph(vars, d::Dense, x) =
+#   mx.FullyConnected(data = x,
+#                     num_hidden = size(d.W.x, 1),
+#                     weight = graph(vars, d.W),
+#                     bias = graph(vars, d.b))
+
+function interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}})
+  id = gensym()
+  ctx[:params][id] = p.value.x
+  return mx.Variable(id)
+end
+
+interp(ctx, p::Constant) = node(p.value)
+
+function graph(ctx::Context, model, args...)
+  node = graph(model, interpv(ctx, args)...)
+  # isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
+  return node
+end
+
+function interp(ctx, model, args...)
+  g = Flux.graph(model)
+  g == nothing && return graph(ctx, model, args...)
+  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
+  interpret(ctx, g, interpv(ctx, args)...)
+end
+
+function tograph(model, args...)
+  ctx = Context(mux(iline, ilambda, ituple, imap, interp),
+                params = Dict(), stacks = Dict())
+  out = interp(ctx, model, map(constant, args)...)
+  return ctx[:params], ctx[:stacks], out
+end
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
new file mode 100644
index 00000000..fe27d8d9
--- /dev/null
+++ b/src/backend/mxnet/model.jl
@@ -0,0 +1,101 @@
+using MacroTools
+
+type MXModel <: Model
+  model::Any
+  params::Dict{Symbol,Any}
+  grads::Dict{Symbol,Any}
+  exec::mx.Executor
+end
+
+mxdims(dims::NTuple) = reverse(dims)
+
+mxdims(n::Integer) = mxdims((n,))
+
+function tond!(nd::mx.NDArray, xs::AArray)
+  mx.copy_ignore_shape!(nd, xs')
+  nd
+end
+
+tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs)
+
+fromnd(xs::mx.NDArray) = copy(xs)'
+
+ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
+
+function mxargs(args)
+  map(args) do kv
+    arg, value = kv
+    arg => tond(value)
+  end
+end
+
+function mxgrads(mxargs)
+  map(mxargs) do kv
+    arg, value = kv
+    arg => mx.zeros(size(value))
+  end
+end
+
+function loadparams!(model::MXModel)
+  for (name, arr) in model.exec.arg_dict
+    haskey(model.params, name) && tond!(arr, model.params[name])
+  end
+  return model
+end
+
+function mxnet(model::Model, input)
+  params, stacks, node = tograph(model, mx.Variable(:input))
+  args = merge(mxargs(params), Dict(:input => mx.zeros(mxdims(input))))
+  grads = mxgrads(args)
+  model = MXModel(model, params, grads,
+                  mx.bind(node, args = args,
+                          args_grad = grads,
+                          grad_req = mx.GRAD_ADD))
+  loadparams!(model)
+  return model
+end
+
+function (model::MXModel)(input)
+  tond!(model.exec.arg_dict[:input], input)
+  mx.forward(model.exec, is_train = true)
+  fromnd(model.exec.outputs[1])
+end
+
+function Flux.back!(model::MXModel, Δ, x)
+  ndzero!(model.grads[:input])
+  mx.backward(model.exec, tond(Δ))
+  fromnd(model.grads[:input])
+end
+
+function Flux.update!(model::MXModel, η)
+  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
+    mx.@nd_as_jl rw = (arg, grad) begin
+      arg .-= grad .* η
+      grad[:] = 0
+    end
+  end
+  return model
+end
+
+# MX FeedForward interface
+
+type SoftmaxOutput
+  name::Symbol
+end
+
+graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
+
+function rewrite_softmax(model, name)
+  model == softmax && return SoftmaxOutput(name)
+  g = Flux.graph(model)
+  (g == nothing || value(g) ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
+  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
+end
+
+function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
+  model = rewrite_softmax(model, label)
+  node, vars = mxgraph(model, input)
+  ff = mx.FeedForward(node, context = context)
+  ff.arg_params = mxargs(vars)
+  return ff
+end
diff --git a/src/backend/mxnet/mxnet.jl b/src/backend/mxnet/mxnet.jl
new file mode 100644
index 00000000..8f8be89a
--- /dev/null
+++ b/src/backend/mxnet/mxnet.jl
@@ -0,0 +1,10 @@
+module MX
+
+using MXNet, DataFlow, ..Flux
+
+export mxnet
+
+include("graph.jl")
+include("model.jl")
+
+end
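
For reference, a rough sketch of how the new backend is meant to be driven, based only on the calls added above; `m`, `xs`, `Δ`, the (1, 20) input shape, and the 0.01 learning rate are illustrative placeholders, not part of the patch:

    mxmodel = mxnet(m, (1, 20))    # compile the Flux model `m` against MXNet; second argument is the input shape
    ys = mxmodel(xs)               # forward pass through the bound executor (mx.forward)
    Flux.back!(mxmodel, Δ, xs)     # backward pass; returns the gradient w.r.t. the input
    Flux.update!(mxmodel, 0.01)    # subtract 0.01 .* grad from each argument, then zero the gradients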