transpose everything going into mxnet

Mike J Innes 2017-03-08 17:35:15 +00:00
parent 3b004bac7d
commit 9d1d176749
3 changed files with 34 additions and 50 deletions
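
The gist of the change: rather than permuting every array at the MXNet boundary on each call (the old rebatch_last/rebatch_first dance, deleted below), parameters and input shapes are transposed once, so Flux's batch-first, column-major arrays line up with MXNet's row-major layout directly. A minimal sketch of the layout fact this relies on; rowmajor_view is a hypothetical helper emulating how a row-major consumer indexes a Julia buffer:

    # Julia is column-major, MXNet's NDArray is row-major (NumPy-style), so the
    # same memory interpreted by the other side appears transposed. Flipping
    # shapes and weights once therefore replaces per-call copying.
    A = [1 2 3; 4 5 6]                 # 2×3 Julia matrix; memory: 1 4 2 5 3 6
    rowmajor_view(buf, m, n) = permutedims(reshape(buf, n, m))
    @assert rowmajor_view(vec(A), 3, 2) == A'   # a row-major reader sees the 3×2 transpose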

View File

@@ -19,7 +19,7 @@ node(x::mx.SymbolicNode) = x
 graph(::typeof(tuple), args...) = (args...,)
 graph(::typeof(+), args...) = mx.broadcast_plus(args...)
-graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching
+graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
 graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
 graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
 graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
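
The reversed mx.dot leans on the identity (A·B)ᵀ = Bᵀ·Aᵀ: since both operands already live transposed from MXNet's point of view, multiplying them in reverse order yields the product in the transposed layout too. A quick plain-Julia check of the identity:

    A, B = rand(3, 4), rand(4, 5)
    @assert (A * B)' ≈ B' * A'   # reversing the factors transposes the product
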
@@ -38,10 +38,10 @@ graph(ctx::Context, d::Affine, x) =
   register(ctx,
     mx.FullyConnected(mx.SymbolicNode, data = x,
                       num_hidden = size(d.W.x, 2),
-                      weight = var(ctx, AlterParam(d.W, false, false)),
-                      bias = var(ctx, AlterParam(d.b, true, false))))
-# TODO: use actual params}
+                      weight = var(ctx, AlterParam(d.W, x->x', nothing)),
+                      bias = var(ctx, AlterParam(d.b, x->squeeze(x, 1), nothing))))
+# TODO: use actual params
 graph(ctx::Context, c::Conv2D, x) =
   mx.Convolution(x,
                  kernel = size(c.filter, 1, 2),
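
MXNet's FullyConnected computes x·Wᵀ + b with a (num_hidden, input) weight and a flat bias, hence the x->x' and squeeze(x, 1) load transforms above. A plain-Julia sketch of the shape bookkeeping (sizes and names are illustrative; vec plays the role of squeeze(x, 1) on a 1×n row):

    n, nin, nout = 5, 20, 10
    x   = rand(n, nin)      # batch-first input
    W   = rand(nin, nout)   # Flux-side Affine weight
    b   = rand(1, nout)     # Flux-side bias, stored as a row
    mxW = W'                # what the x -> x' load transform hands over
    mxb = vec(b)            # flat length-nout bias, as squeeze(b, 1) would give
    y   = x * mxW' .+ mxb'  # FullyConnected-style x*Wᵀ + b
    @assert size(y) == (n, nout)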

View File

@@ -1,25 +1,12 @@
 using Flux: batchone, unbatchone, rebatch

-# MNet batches on last dimension
-rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
-rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
-rebatch_first(xs::Tuple) = rebatch_first.(xs)
-
-paramvalue(p) = rebatch_last(p)
-paramvalue(p::Flux.Param) = paramvalue(p.x)
-
-# Basically a kludge to make Affine work
-# Hopefully will go away with more inference
 type AlterParam
-  param::Flux.Param
-  strip::Bool
-  rebatch::Bool
+  param
+  load
+  store
 end

-function paramvalue(p::AlterParam)
-  val = p.rebatch ? paramvalue(p.param) : p.param.x
-  p.strip ? squeeze(val, 1) : val
-end
+Base.size(p::AlterParam) = size(p.load(p.param.x))

 type Graph
   output
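
AlterParam goes from two opaque booleans to arbitrary load/store transforms, and Base.size reports the loaded (MXNet-side) shape rather than the Flux-side one. A self-contained sketch of the contract, written with struct for current Julia (Param stands in for Flux.Param):

    struct Param; x; end                        # stand-in for Flux.Param
    struct AlterParam; param; load; store; end  # fields as in the diff
    Base.size(p::AlterParam) = size(p.load(p.param.x))

    W  = Param(rand(20, 10))
    ap = AlterParam(W, x -> x', nothing)  # load as the transpose; no store yet
    @assert size(ap) == (10, 20)          # the shape MXNet will allocate
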
@@ -28,34 +15,32 @@ type Graph
 end

 function mxparams(g::Graph)
-  params = Dict{Symbol,mx.NDArray}()
+  params = Dict{Symbol,MXArray}()
   for (name, param) in g.params
-    params[name] = mx.zeros(size(paramvalue(param)))
+    params[name] = MXArray(size(param))
   end
   return params
 end

-function loadparams!(g::Graph, args)
-  for (id, param) in g.params
-    haskey(args, id) && copy!(args[id], paramvalue(param))
-  end
-end
-
-function storeparams!(g::Graph, args)
-  for (id, param) in g.params
-    haskey(args, id) && copy!(param.x, rebatch_first(copy(args[id])))
-  end
-end
+function copyargs!(as, bs)
+  for id in intersect(keys(as), keys(bs))
+    copy!(as[id], bs[id])
+  end
+end
+
+ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)

 type Model <: Flux.Model
   model::Any
   graph::Graph
-  grads::Dict{Symbol,Any}
+  args::Dict{Symbol,MXArray}
+  grads::Dict{Symbol,MXArray}
+  outs::Vector{MXArray}
   exec::mx.Executor
 end

-loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
-storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
+loadparams!(model::Model) = copyargs!(model.args, model.graph.params)
+storeparams!(model::Model) = copyargs!(model.graph.params, model.args)

 mxgroup(x) = x
 mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
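
copyargs! collapses the two near-duplicate loadparams!/storeparams! loops into one function that copies values for whichever keys the two dictionaries share, so the same code moves parameters in either direction. A sketch with plain arrays standing in for MXArrays:

    function copyargs!(as, bs)
      for id in intersect(keys(as), keys(bs))
        copy!(as[id], bs[id])
      end
    end

    as = Dict(:W => zeros(2), :b => zeros(3))
    bs = Dict(:W => [1.0, 2.0], :extra => [9.0])
    copyargs!(as, bs)
    @assert as[:W] == [1.0, 2.0]  # shared key copied; :b and :extra untouched
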
@@ -64,35 +49,34 @@ mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
 function mxnet(model::Flux.Model, input)
   graph = tograph(model, mx.Variable(:input))
-  args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
-  grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
-  model = @mxerr graph.stacks Model(model, graph, grads,
-                  mx.bind(mxgroup(graph.output), args = args,
-                          args_grad = grads,
-                          grad_req = mx.GRAD_ADD))
+  args = merge(mxparams(graph), Dict(:input => MXArray(input)))
+  grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
+  exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
+                                     args = ndparams(args),
+                                     args_grad = ndparams(grads),
+                                     grad_req = mx.GRAD_ADD)
+  model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
   loadparams!(model)
   return model
 end

 function runmodel(model::Model, input)
-  copy!(model.exec.arg_dict[:input], input)
+  copy!(model.args[:input], input)
   mx.forward(model.exec, is_train = true)
-  mxungroup(model.graph.output, copy(model.exec.outputs))
+  mxungroup(model.graph.output, copy(model.outs))
 end

-(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
+(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))

 (m::Model)(x) = unbatchone(m(batchone(x)))

-tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
-
 function runback!(model::Model, Δ)
   model.grads[:input][:] = 0
-  mx.backward(model.exec, tond(Δ))
+  mx.backward(model.exec, MXArray(Δ).data)
   copy(model.grads[:input])
 end

-Flux.back!(m::Model, Δ::Batch, x) = rebatch(rebatch_first(runback!(m, rebatch_last(rawbatch(Δ)))))
+Flux.back!(m::Model, Δ::Batch, x) = rebatch(runback!(m, rawbatch(Δ)))
 Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
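
For contrast, the forward and backward entry points above previously wrapped every call in the now-deleted permutedims pair, paying two full array copies per call just to move the batch dimension. The deleted helpers, runnable standalone:

    rebatch_last(xs)  = permutedims(xs, (2:ndims(xs)..., 1))           # batch dim last
    rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...)) # and back

    x = rand(5, 20)                              # a batch-first matrix
    @assert rebatch_first(rebatch_last(x)) == x  # round trip: two copies
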
@@ -126,6 +110,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont
   model = rewrite_softmax(model, label)
   graph = tograph(model, mx.Variable(input), feedforward=true)
   ff = mx.FeedForward(graph.output, context = context)
-  isempty(graph.params) || (ff.arg_params = mxparams(graph))
+  isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
   return ff
 end
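
ndparams strips the MXArray wrapper (its raw NDArray lives in the .data field, per the diff) before anything is handed to mx.bind or FeedForward. A sketch with a hypothetical stand-in wrapper:

    struct MXArrayStub; data; end   # stand-in: MXArray keeps the raw array in .data
    ndparams(d) = Dict(k => v.data for (k, v) in d)

    d = Dict(:W => MXArrayStub(rand(2, 2)))
    @assert ndparams(d)[:W] === d[:W].data  # unwrapped, same underlying object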

View File

@@ -6,11 +6,11 @@ Flux.loadmx()
 xs = rand(20)
 d = Affine(20, 10)
-dm = mxnet(d, (20, 1))
+dm = mxnet(d, (1, 20))
 @test d(xs) ≈ dm(xs)

 m = Multi(20, 15)
-mm = mxnet(m, (20, 1))
+mm = mxnet(m, (1, 20))
 @test all(isapprox.(mm(xs), m(xs)))

 @testset "Backward Pass" begin
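
The test shapes flip from (20, 1) to (1, 20) to match the batch-first convention: one sample of 20 features is now a 1×20 row. A plain-Julia sketch, with reshape standing in for whatever batchone produces under the new layout (an assumption, not taken from the diff):

    xs = rand(20)
    batch = reshape(xs, 1, 20)      # one sample as a single-row batch
    @assert size(batch) == (1, 20)  # batch dimension first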