tweak param loading

Mike J Innes 2017-02-23 18:48:46 +00:00
parent 2f2ff0b03b
commit 2d77220d60
2 changed files with 36 additions and 32 deletions


@@ -38,8 +38,8 @@ graph(ctx::Context, d::Affine, x) =
   register(ctx,
     mx.FullyConnected(x,
                       num_hidden = size(d.W.x, 2),
-                      weight = var(ctx, d.W, size(d.W)),
-                      bias = var(ctx, d.b, size(d.b, 2))))
+                      weight = var(ctx, AlterParam(d.W, false, false)),
+                      bias = var(ctx, AlterParam(d.b, true, false))))
 
 # TODO: use actual params
 graph(ctx::Context, c::Conv2D, x) =
@@ -61,9 +61,9 @@ end
 
 register(ctx::Context, node) = node
 
-function var(ctx::Context, p::Flux.Param, size = nothing)
+function var(ctx::Context, p)
   id = gensym()
-  ctx[:params][id] = size == nothing ? rebatch_last(p.x) : reshape(p.x, size...)
+  ctx[:params][id] = p
   return mx.Variable(id)
 end
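
The upshot of the two hunks above: var no longer eagerly reshapes a parameter into ctx[:params]; it stores the object itself (a Flux.Param or an AlterParam wrapper, defined in the second file below) and defers any layout conversion to load time. A minimal sketch of the new call pattern, assuming a Context ctx is in scope and that Affine(10, 20) keeps its weight as a 10×20 matrix in d.W.x and its bias as a 1×20 row in d.b.x:

    d = Affine(10, 20)
    w = var(ctx, AlterParam(d.W, false, false))  # ctx[:params][id] holds the wrapper itself
    b = var(ctx, AlterParam(d.b, true, false))   # strip = true: the 1×20 bias row is squeezed later, in paramvalue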


@@ -1,11 +1,39 @@
 using Flux: batchone, rebatch
 
+# MNet batches on last dimension
+rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
+rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
+paramvalue(p) = rebatch_last(p)
+paramvalue(p::Flux.Param) = paramvalue(p.x)
+# Basically a kludge to make Affine work
+# Hopefully will go away with more inference
+type AlterParam
+  param::Flux.Param
+  strip::Bool
+  rebatch::Bool
+end
+function paramvalue(p::AlterParam)
+  val = p.rebatch ? paramvalue(p.param) : p.param.x
+  p.strip ? squeeze(val, 1) : val
+end
+type Graph
+  node::mx.SymbolicNode
+  params::Dict{Symbol,Any}
+  stacks::Dict{Any,Any}
+end
+function mxparams(g::Graph)
+  params = Dict{Symbol,mx.NDArray}()
+  for (name, param) in g.params
+    params[name] = mx.zeros(size(paramvalue(param)))
+  end
+  return params
+end
 
 type Model <: Flux.Model
   model::Any
   graph::Graph
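
To make AlterParam concrete: strip squeezes a leading singleton dimension (Flux stores biases as 1×n rows; MXNet expects length-n vectors), and rebatch routes plain params through rebatch_last so the batch dimension ends up last. A rough illustration with made-up shapes, using only the definitions above (and assuming a Flux.Param can be wrapped around a bare array; the era's usual constructor may differ):

    W = Flux.Param(rand(10, 20))
    b = Flux.Param(rand(1, 20))
    size(paramvalue(AlterParam(W, false, false)))  # (10, 20): weight passed through untouched
    size(paramvalue(AlterParam(b, true, false)))   # (20,): singleton row dimension stripped
    size(paramvalue(rand(5, 10)))                  # (10, 5): bare arrays get rebatched to batch-last

mxparams then sizes its zero-filled NDArrays from these converted shapes, so the executor buffers always match what loadparams! will later copy in.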
@@ -13,37 +41,17 @@ type Model <: Flux.Model
   exec::mx.Executor
 end
 
-tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
-ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
-
-function mxargs(args)
-  isempty(args) && return Dict{Symbol,mx.NDArray}()
-  map(args) do kv
-    arg, value = kv
-    arg => tond(value)
-  end
-end
-
-function mxgrads(mxargs)
-  isempty(mxargs) && return Dict{Symbol,mx.NDArray}()
-  map(mxargs) do kv
-    arg, value = kv
-    arg => mx.zeros(size(value))
-  end
-end
-
 function loadparams!(model::Model)
   for (name, arr) in model.exec.arg_dict
-    haskey(model.graph.params, name) && copy!(arr, model.graph.params[name])
+    haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name]))
   end
   return model
 end
 
 function mxnet(model::Flux.Model, input)
   graph = tograph(model, mx.Variable(:input))
-  args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input)))
-  grads = mxgrads(args)
+  args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
+  grads = mxparams(graph)
   model = @mxerr graph.stacks Model(model, graph, grads,
                                     mx.bind(graph.node, args = args,
                                             args_grad = grads,
@@ -52,10 +60,6 @@ function mxnet(model::Flux.Model, input)
   return model
 end
 
-# MNet batches on last dimension
-rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
-rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
-
 function runmodel(model::Model, input)
   copy!(model.exec.arg_dict[:input], input)
   mx.forward(model.exec, is_train = true)
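
Net effect on binding: with mxargs and mxgrads gone, both argument and gradient buffers are allocated as zeros by mxparams, and loadparams! becomes the one place where Julia-side values reach the executor (routed through paramvalue, so wrapped params get the same conversion used at allocation). One nuance visible in the diff: the old grads = mxgrads(args) also allocated a gradient buffer for :input, while mxparams(graph) covers parameters only. A hedged sketch of the resulting flow, mirroring the calls shown above (model and input are placeholders):

    graph = tograph(model, mx.Variable(:input))
    args  = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
    grads = mxparams(graph)                          # fresh zero buffers, params only
    exec  = mx.bind(graph.node, args = args, args_grad = grads)
    # loadparams! then copies paramvalue(p) into exec.arg_dict for each registered param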
@@ -101,6 +105,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont
   model = rewrite_softmax(model, label)
   graph = tograph(model, mx.Variable(input), feedforward=true)
   ff = mx.FeedForward(graph.node, context = context)
-  isempty(graph.params) || (ff.arg_params = mxargs(graph.params))
+  isempty(graph.params) || (ff.arg_params = mxparams(graph))
   return ff
 end
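
Same substitution on the FeedForward path, with one consequence worth noting from the diff alone: mxargs copied the parameters' current values into arg_params, whereas mxparams yields zero-filled NDArrays of the right shapes, so actual values now have to be loaded into the FeedForward separately. Hypothetical usage:

    ff = mx.FeedForward(model)   # input = :data, label = :softmax by default
    # ff.arg_params :: Dict{Symbol,mx.NDArray}, zero-initialised, shaped via paramvalue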