transpose everything going into mxnet
This commit is contained in:
parent
3b004bac7d
commit
9d1d176749
|
@ -19,7 +19,7 @@ node(x::mx.SymbolicNode) = x
|
|||
|
||||
# A call to `tuple` is represented by a plain Julia tuple holding its arguments.
graph(::typeof(tuple), args...) = (args...,)
|
||||
# `+` over any number of arguments is forwarded to `mx.broadcast_plus`.
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
# Two-operand `*`: the operands are handed to `mx.dot` in swapped order (`W` first, then `x`).
graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching
|
||||
# `*` over any number of operands: they are handed to `mx.dot` in reversed order.
graph(::typeof(*), xs...) = mx.dot(reverse(xs)...) # Work around MXNet shape hack
|
||||
# `σ` is emitted as an `mx.Activation` node with `act_type = :sigmoid`.
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
|
||||
# `relu` is emitted as an `mx.Activation` node with `act_type = :relu`.
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
|
||||
# `tanh` is emitted as an `mx.Activation` node with `act_type = :tanh`.
graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
|
||||
|
@ -38,10 +38,10 @@ graph(ctx::Context, d::Affine, x) =
|
|||
register(ctx,
|
||||
mx.FullyConnected(mx.SymbolicNode, data = x,
|
||||
num_hidden = size(d.W.x, 2),
|
||||
weight = var(ctx, AlterParam(d.W, false, false)),
|
||||
bias = var(ctx, AlterParam(d.b, true, false))))
|
||||
weight = var(ctx, AlterParam(d.W, x->x', nothing)),
|
||||
bias = var(ctx, AlterParam(d.b, x->squeeze(x, 1), nothing))))
|
||||
|
||||
# TODO: use actual params}
|
||||
# TODO: use actual params
|
||||
graph(ctx::Context, c::Conv2D, x) =
|
||||
mx.Convolution(x,
|
||||
kernel = size(c.filter, 1, 2),
|
||||
|
|
|
@ -1,25 +1,12 @@
|
|||
using Flux: batchone, unbatchone, rebatch
|
||||
|
||||
# MXNet batches on the last dimension
|
||||
# Permute `xs` so that its first dimension becomes the last one; the remaining
# dimensions keep their relative order.
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
|
||||
# Permute `xs` so that its last dimension becomes the first one; the remaining
# dimensions keep their relative order.
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
|
||||
# Tuple input: apply `rebatch_first` to every element via broadcasting.
rebatch_first(xs::Tuple) = rebatch_first.(xs)
|
||||
|
||||
# Fallback for any `p`: the value is `p` passed through `rebatch_last`.
paramvalue(p) = rebatch_last(p)
|
||||
# For a `Flux.Param`, unwrap the stored value `p.x` and dispatch on it again.
paramvalue(p::Flux.Param) = paramvalue(p.x)
|
||||
|
||||
# Basically a kludge to make Affine work
|
||||
# Hopefully will go away with more inference
|
||||
type AlterParam
|
||||
param::Flux.Param
|
||||
strip::Bool
|
||||
rebatch::Bool
|
||||
param
|
||||
load
|
||||
store
|
||||
end
|
||||
|
||||
function paramvalue(p::AlterParam)
|
||||
val = p.rebatch ? paramvalue(p.param) : p.param.x
|
||||
p.strip ? squeeze(val, 1) : val
|
||||
end
|
||||
# Size of an `AlterParam`: the size of the wrapped value `p.param.x` after the
# `p.load` function has been applied to it.
Base.size(p::AlterParam) = size(p.load(p.param.x))
|
||||
|
||||
type Graph
|
||||
output
|
||||
|
@ -28,34 +15,32 @@ type Graph
|
|||
end
|
||||
|
||||
function mxparams(g::Graph)
|
||||
params = Dict{Symbol,mx.NDArray}()
|
||||
params = Dict{Symbol,MXArray}()
|
||||
for (name, param) in g.params
|
||||
params[name] = mx.zeros(size(paramvalue(param)))
|
||||
params[name] = MXArray(size(param))
|
||||
end
|
||||
return params
|
||||
end
|
||||
|
||||
function loadparams!(g::Graph, args)
|
||||
for (id, param) in g.params
|
||||
haskey(args, id) && copy!(args[id], paramvalue(param))
|
||||
function copyargs!(as, bs)
|
||||
for id in intersect(keys(as), keys(bs))
|
||||
copy!(as[id], bs[id])
|
||||
end
|
||||
end
|
||||
|
||||
function storeparams!(g::Graph, args)
|
||||
for (id, param) in g.params
|
||||
haskey(args, id) && copy!(param.x, rebatch_first(copy(args[id])))
|
||||
end
|
||||
end
|
||||
# Build a new Dict with the same keys as `d`, mapping each key to the `data` field
# of its value (presumably the underlying MXNet array; confirm against the
# `MXArray` definition, which is not visible here).
ndparams(d::Dict{Symbol,MXArray}) = Dict(k => v.data for (k, v) in d)
|
||||
|
||||
type Model <: Flux.Model
|
||||
model::Any
|
||||
graph::Graph
|
||||
grads::Dict{Symbol,Any}
|
||||
args::Dict{Symbol,MXArray}
|
||||
grads::Dict{Symbol,MXArray}
|
||||
outs::Vector{MXArray}
|
||||
exec::mx.Executor
|
||||
end
|
||||
|
||||
# Model-level convenience method: forwards to the graph-level `loadparams!`,
# passing the executor's `arg_dict`.
loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
|
||||
# Model-level convenience method: forwards to the graph-level `storeparams!`,
# passing the executor's `arg_dict`.
storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
|
||||
# Copy the entries of `model.graph.params` into `model.args` via `copyargs!`
# (first argument is the destination).
loadparams!(model::Model) = copyargs!(model.args, model.graph.params)
|
||||
# Copy the entries of `model.args` back into `model.graph.params` via `copyargs!`
# (first argument is the destination).
storeparams!(model::Model) = copyargs!(model.graph.params, model.args)
|
||||
|
||||
# Fallback: a value that is not a tuple is returned unchanged.
mxgroup(x) = x
|
||||
# Tuple input: each element is passed through `mxgroup` (so nested tuples are
# handled recursively) and the results are combined with `mx.Group`.
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
|
||||
|
@ -64,35 +49,34 @@ mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
|||
|
||||
function mxnet(model::Flux.Model, input)
|
||||
graph = tograph(model, mx.Variable(:input))
|
||||
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||
grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||
model = @mxerr graph.stacks Model(model, graph, grads,
|
||||
mx.bind(mxgroup(graph.output), args = args,
|
||||
args_grad = grads,
|
||||
grad_req = mx.GRAD_ADD))
|
||||
args = merge(mxparams(graph), Dict(:input => MXArray(input)))
|
||||
grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
|
||||
exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
|
||||
args = ndparams(args),
|
||||
args_grad = ndparams(grads),
|
||||
grad_req = mx.GRAD_ADD)
|
||||
model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
|
||||
loadparams!(model)
|
||||
return model
|
||||
end
|
||||
|
||||
function runmodel(model::Model, input)
|
||||
copy!(model.exec.arg_dict[:input], input)
|
||||
copy!(model.args[:input], input)
|
||||
mx.forward(model.exec, is_train = true)
|
||||
mxungroup(model.graph.output, copy(model.exec.outputs))
|
||||
mxungroup(model.graph.output, copy(model.outs))
|
||||
end
|
||||
|
||||
# Batched call: take the raw batch data, move its first dimension last
# (`rebatch_last`), run the model, move the last dimension of the result back to
# the front (`rebatch_first`), and wrap the result with `rebatch`.
(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
|
||||
# Batched call: run the model directly on the raw batch data and wrap the result
# with `rebatch`; no dimension permutation is performed here.
(m::Model)(x::Batch) = rebatch(runmodel(m, rawbatch(x)))
|
||||
|
||||
# Unbatched call: wrap `x` with `batchone`, invoke the model on the result, and
# unwrap the output with `unbatchone`.
(m::Model)(x) = unbatchone(m(batchone(x)))
|
||||
|
||||
# Allocate an `mx.zeros` array with the same size as `xs`, copy `xs` into it with
# `copy!`, and return the result of that `copy!` call.
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||
|
||||
function runback!(model::Model, Δ)
|
||||
model.grads[:input][:] = 0
|
||||
mx.backward(model.exec, tond(Δ))
|
||||
mx.backward(model.exec, MXArray(Δ).data)
|
||||
copy(model.grads[:input])
|
||||
end
|
||||
|
||||
# Batched backward pass: take the raw data of `Δ`, move its first dimension last
# (`rebatch_last`), run `runback!`, move the last dimension of the result back to
# the front (`rebatch_first`), and wrap it with `rebatch`. The argument `x` is not used.
Flux.back!(m::Model, Δ::Batch, x) = rebatch(rebatch_first(runback!(m, rebatch_last(rawbatch(Δ)))))
|
||||
# Batched backward pass: run `runback!` directly on the raw data of `Δ` and wrap
# the result with `rebatch`. The argument `x` is not used.
Flux.back!(m::Model, Δ::Batch, x) = rebatch(runback!(m, rawbatch(Δ)))
|
||||
|
||||
# Unbatched backward pass: wrap `Δ` with `batchone`, call `Flux.back!` on the
# result, and return the first element of what it returns.
Flux.back!(m::Model, Δ, x) = first(Flux.back!(m, batchone(Δ), x))
|
||||
|
||||
|
@ -126,6 +110,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont
|
|||
model = rewrite_softmax(model, label)
|
||||
graph = tograph(model, mx.Variable(input), feedforward=true)
|
||||
ff = mx.FeedForward(graph.output, context = context)
|
||||
isempty(graph.params) || (ff.arg_params = mxparams(graph))
|
||||
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
|
||||
return ff
|
||||
end
|
||||
|
|
|
@ -6,11 +6,11 @@ Flux.loadmx()
|
|||
xs = rand(20)
|
||||
d = Affine(20, 10)
|
||||
|
||||
dm = mxnet(d, (20, 1))
|
||||
dm = mxnet(d, (1, 20))
|
||||
@test d(xs) ≈ dm(xs)
|
||||
|
||||
m = Multi(20, 15)
|
||||
mm = mxnet(m, (20, 1))
|
||||
mm = mxnet(m, (1, 20))
|
||||
@test all(isapprox.(mm(xs), m(xs)))
|
||||
|
||||
@testset "Backward Pass" begin
|
||||
|
|
Loading…
Reference in New Issue