diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index c60f1db6..ec9b008d 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -38,8 +38,8 @@
 graph(ctx::Context, d::Affine, x) =
   register(ctx,
     mx.FullyConnected(x,
                       num_hidden = size(d.W.x, 2),
-                      weight = var(ctx, d.W, size(d.W)),
-                      bias = var(ctx, d.b, size(d.b, 2))))
+                      weight = var(ctx, AlterParam(d.W, false, false)),
+                      bias = var(ctx, AlterParam(d.b, true, false)))) # TODO: use actual params
 
 graph(ctx::Context, c::Conv2D, x) =
@@ -61,9 +61,9 @@ end
 
 register(ctx::Context, node) = node
 
-function var(ctx::Context, p::Flux.Param, size = nothing)
+function var(ctx::Context, p)
   id = gensym()
-  ctx[:params][id] = size == nothing ? rebatch_last(p.x) : reshape(p.x, size...)
+  ctx[:params][id] = p
   return mx.Variable(id)
 end
 
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 10f3df17..835fed7b 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -1,11 +1,39 @@
 using Flux: batchone, rebatch
 
+# MNet batches on last dimension
+rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
+rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
+
+paramvalue(p) = rebatch_last(p)
+paramvalue(p::Flux.Param) = paramvalue(p.x)
+
+# Basically a kludge to make Affine work
+# Hopefully will go away with more inference
+type AlterParam
+  param::Flux.Param
+  strip::Bool
+  rebatch::Bool
+end
+
+function paramvalue(p::AlterParam)
+  val = p.rebatch ? paramvalue(p.param) : p.param.x
+  p.strip ? squeeze(val, 1) : val
+end
+
 type Graph
   node::mx.SymbolicNode
   params::Dict{Symbol,Any}
   stacks::Dict{Any,Any}
 end
 
+function mxparams(g::Graph)
+  params = Dict{Symbol,mx.NDArray}()
+  for (name, param) in g.params
+    params[name] = mx.zeros(size(paramvalue(param)))
+  end
+  return params
+end
+
 type Model <: Flux.Model
   model::Any
   graph::Graph
@@ -13,37 +41,17 @@ type Model <: Flux.Model
   exec::mx.Executor
 end
 
-tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
-
-ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
-
-function mxargs(args)
-  isempty(args) && return Dict{Symbol,mx.NDArray}()
-  map(args) do kv
-    arg, value = kv
-    arg => tond(value)
-  end
-end
-
-function mxgrads(mxargs)
-  isempty(mxargs) && return Dict{Symbol,mx.NDArray}()
-  map(mxargs) do kv
-    arg, value = kv
-    arg => mx.zeros(size(value))
-  end
-end
-
 function loadparams!(model::Model)
   for (name, arr) in model.exec.arg_dict
-    haskey(model.graph.params, name) && copy!(arr, model.graph.params[name])
+    haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name]))
   end
   return model
 end
 
 function mxnet(model::Flux.Model, input)
   graph = tograph(model, mx.Variable(:input))
-  args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input)))
-  grads = mxgrads(args)
+  args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
+  grads = mxparams(graph)
   model = @mxerr graph.stacks Model(model, graph, grads,
                                     mx.bind(graph.node, args = args,
                                             args_grad = grads,
@@ -52,10 +60,6 @@ function mxnet(model::Flux.Model, input)
   return model
 end
 
-# MNet batches on last dimension
-rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
-rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
-
 function runmodel(model::Model, input)
   copy!(model.exec.arg_dict[:input], input)
   mx.forward(model.exec, is_train = true)
@@ -101,6 +105,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
   model = rewrite_softmax(model, label)
   graph = tograph(model, mx.Variable(input), feedforward=true)
   ff = mx.FeedForward(graph.node, context = context)
-  isempty(graph.params) || (ff.arg_params = mxargs(graph.params))
+  isempty(graph.params) || (ff.arg_params = mxparams(graph))
   return ff
 end
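For reference, here is a minimal standalone sketch of how the new `paramvalue`/`AlterParam` machinery behaves. `Param` below is a hypothetical stand-in for `Flux.Param` so the snippet runs without Flux or MXNet.jl loaded; the other definitions mirror the diff (Julia 0.5-era syntax):

```julia
# Hypothetical stand-in for Flux.Param, for illustration only.
immutable Param{T}
  x::T
end

# Flux batches on the first dimension, MXNet on the last; the two
# permutations below are inverses of each other.
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))

paramvalue(p) = rebatch_last(p)
paramvalue(p::Param) = paramvalue(p.x)

type AlterParam
  param::Param
  strip::Bool
  rebatch::Bool
end

function paramvalue(p::AlterParam)
  val = p.rebatch ? paramvalue(p.param) : p.param.x
  p.strip ? squeeze(val, 1) : val
end

x = rand(2, 3, 4)
rebatch_first(rebatch_last(x)) == x            # => true

W = Param(rand(3, 5))
size(paramvalue(W))                            # => (5, 3): plain params are rebatched

# Affine's weight is already laid out for mx.FullyConnected, so the diff
# wraps it as AlterParam(d.W, false, false) to pass it through untouched:
size(paramvalue(AlterParam(W, false, false)))  # => (3, 5)

# Affine's bias is stored as a 1×n row; AlterParam(d.b, true, false)
# strips the singleton dimension without rebatching:
b = Param(rand(1, 5))
size(paramvalue(AlterParam(b, true, false)))   # => (5,)
```

`mxparams` then simply allocates a zeroed `mx.NDArray` of `size(paramvalue(param))` for each graph parameter, so the executor's argument and gradient arrays are created with the rebatched/stripped shapes, and `loadparams!` fills them through the same `paramvalue` path.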