basic mxnet backend
parent 3b3a088851
commit 94cb98c13f
@@ -1,4 +1,4 @@
-export tf
+export tf, mxnet
 
 function loadtf()
   isdefined(Flux, :TF) && return
@@ -9,3 +9,13 @@ function tf(args...)
   loadtf()
   TF.tf(args...)
 end
+
+function loadmx()
+  isdefined(Flux, :MX) && return
+  @eval include(joinpath(dirname($@__FILE__), "mxnet/mxnet.jl"))
+end
+
+function mxnet(args...)
+  loadmx()
+  MX.mxnet(args...)
+end
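Like the existing `tf` entry point, `mxnet` loads its backend lazily: the first call runs `loadmx()`, which includes the MXNet code and defines the `MX` module, and later calls dispatch straight through to `MX.mxnet`. A minimal usage sketch (the `Affine` layer and the input shape are illustrative assumptions, not part of this diff):

    using Flux

    model = Flux.Affine(10, 2)       # hypothetical small model to convert
    mxmodel = mxnet(model, (1, 10))  # first call triggers loadmx(); shape is (batch, features)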
src/backend/mxnet/graph.jl (new file, 73 lines)
@@ -0,0 +1,73 @@
using Base: @get!
using DataFlow: Constant, constant, Context, interpret, Split,
  interpv, ituple, ilambda, iconst, iline, stack, mux
using Flux: imap

# TODO: implement Julia's type promotion rules

node(x::Tuple) = map(node, x)
node(x::mx.SymbolicNode) = x
# node(x::Number) = TensorFlow.constant(Float32(x))

graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(*), args...) = mx.dot(reverse(args)...)
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
graph(::typeof(flatten), x) = mx.Flatten(data = x)

graph(::typeof(softmax), xs) =
  mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))

graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
graph(::typeof(vcat), a...) = node(cat, 1, a...)

graph(::Input, x) = x

# graph(vars, c::Conv, x) =
#   mx.Convolution(data = x,
#                  kernel = c.size,
#                  num_filter = c.features,
#                  stride = c.stride)
#
# graph(vars, p::MaxPool, x) =
#   mx.Pooling(data = x,
#              pool_type = :max,
#              kernel = p.size,
#              stride = p.stride)
#
# graph(vars, d::Dense, x) =
#   mx.FullyConnected(data = x,
#                     num_hidden = size(d.W.x, 1),
#                     weight = graph(vars, d.W),
#                     bias = graph(vars, d.b))

function interp{T<:AArray}(ctx, p::Constant{Flux.Param{T}})
  id = gensym()
  ctx[:params][id] = p.value.x
  return mx.Variable(id)
end

interp(ctx, p::Constant) = node(p.value)

function graph(ctx::Context, model, args...)
  node = graph(model, interpv(ctx, args)...)
  # isa(node, Tensor) && (ctx[:stacks][node.op.name] = stack(ctx))
  return node
end

function interp(ctx, model, args...)
  g = Flux.graph(model)
  g == nothing && return graph(ctx, model, args...)
  DataFlow.iscyclic(g) && error("This model has a cycle; try unrolling it first.")
  interpret(ctx, g, interpv(ctx, args)...)
end

function tograph(model, args...)
  ctx = Context(mux(iline, ilambda, ituple, imap, interp),
                params = Dict(), stacks = Dict())
  out = interp(ctx, model, map(constant, args)...)
  return ctx[:params], ctx[:stacks], out
end
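graph.jl translates Flux primitives into MXNet symbolic nodes by dispatching `graph` on the function being mapped, so the op table grows one method at a time. A sketch of how a further primitive could be wired up in the same style (`leakyrelu` is an illustrative name, not defined in this commit):

    # Hypothetical extension following the pattern above; assumes MXNet's
    # LeakyReLU operator, exposed by MXNet.jl as mx.LeakyReLU.
    graph(::typeof(leakyrelu), x) = mx.LeakyReLU(data = x, act_type = :leaky)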
src/backend/mxnet/model.jl (new file, 101 lines)
@@ -0,0 +1,101 @@
using MacroTools

type MXModel <: Model
  model::Any
  params::Dict{Symbol,Any}
  grads::Dict{Symbol,Any}
  exec::mx.Executor
end

mxdims(dims::NTuple) = reverse(dims)
mxdims(n::Integer) = mxdims((n,))

function tond!(nd::mx.NDArray, xs::AArray)
  mx.copy_ignore_shape!(nd, xs')
  nd
end

tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs)

fromnd(xs::mx.NDArray) = copy(xs)'

ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))

function mxargs(args)
  map(args) do kv
    arg, value = kv
    arg => tond(value)
  end
end

function mxgrads(mxargs)
  map(mxargs) do kv
    arg, value = kv
    arg => mx.zeros(size(value))
  end
end

function loadparams!(model::MXModel)
  for (name, arr) in model.exec.arg_dict
    haskey(model.params, name) && tond!(arr, model.params[name])
  end
  return model
end

function mxnet(model::Model, input)
  params, stacks, node = tograph(model, mx.Variable(:input))
  args = merge(mxargs(params), Dict(:input => mx.zeros(mxdims(input))))
  grads = mxgrads(args)
  model = MXModel(model, params, grads,
                  mx.bind(node, args = args,
                          args_grad = grads,
                          grad_req = mx.GRAD_ADD))
  loadparams!(model)
  return model
end

function (model::MXModel)(input)
  tond!(model.exec.arg_dict[:input], input)
  mx.forward(model.exec, is_train = true)
  fromnd(model.exec.outputs[1])
end

function Flux.back!(model::MXModel, Δ, x)
  ndzero!(model.grads[:input])
  mx.backward(model.exec, tond(Δ))
  fromnd(model.grads[:input])
end

function Flux.update!(model::MXModel, η)
  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
    mx.@nd_as_jl rw = (arg, grad) begin
      arg .-= grad .* η
      grad[:] = 0
    end
  end
  return model
end

# MX FeedForward interface

type SoftmaxOutput
  name::Symbol
end

graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)

function rewrite_softmax(model, name)
  model == softmax && return SoftmaxOutput(name)
  g = Flux.graph(model)
  (g == nothing || value(g) ≠ softmax || DataFlow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end

function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
  model = rewrite_softmax(model, label)
  node, vars = mxgraph(model, input)
  ff = mx.FeedForward(node, context = context)
  ff.arg_params = mxargs(vars)
  return ff
end
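`MXModel` covers the standard Flux training surface: calling the model runs `mx.forward`, `Flux.back!` accumulates gradients into the bound `args_grad` arrays (hence `grad_req = mx.GRAD_ADD`), and `Flux.update!` applies a plain gradient-descent step before zeroing the gradients. A hand-rolled loop under those assumptions (`data` and the loss gradient here are illustrative, not from this diff):

    mxmodel = mxnet(model, (1, 10))
    for (x, y) in data               # assumed iterator of (input, target) pairs
      ŷ = mxmodel(x)                 # forward pass via mx.forward
      Δ = ŷ .- y                     # assumed output gradient, e.g. for mean-squared error
      Flux.back!(mxmodel, Δ, x)      # accumulate gradients
      Flux.update!(mxmodel, 0.01)    # SGD step with η = 0.01, then reset grads
    end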
src/backend/mxnet/mxnet.jl (new file, 10 lines)
@@ -0,0 +1,10 @@
module MX

using MXNet, DataFlow, ..Flux

export mxnet

include("graph.jl")
include("model.jl")

end
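model.jl also hooks the converted graph into MXNet's stock training interface: `mx.FeedForward(model)` rewrites a trailing `softmax` into an `mx.SoftmaxOutput` node so MXNet can attach its loss (note that the `mxgraph` call in that method is not defined anywhere in this diff; `tograph` looks like the intended helper). Assuming that wiring, training with MXNet's own loop would look roughly like this (optimizer settings and the data provider are illustrative):

    ff = mx.FeedForward(model, context = mx.cpu())
    mx.fit(ff, mx.SGD(lr = 0.01), trainprovider)  # trainprovider: some mx.AbstractDataProvider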