tweak param loading
This commit is contained in:
parent
2f2ff0b03b
commit
2d77220d60
@ -38,8 +38,8 @@ graph(ctx::Context, d::Affine, x) =
|
||||
register(ctx,
|
||||
mx.FullyConnected(x,
|
||||
num_hidden = size(d.W.x, 2),
|
||||
weight = var(ctx, d.W, size(d.W)),
|
||||
bias = var(ctx, d.b, size(d.b, 2))))
|
||||
weight = var(ctx, AlterParam(d.W, false, false)),
|
||||
bias = var(ctx, AlterParam(d.b, true, false))))
|
||||
|
||||
# TODO: use actual params
|
||||
graph(ctx::Context, c::Conv2D, x) =
|
||||
@ -61,9 +61,9 @@ end
|
||||
|
||||
# Called on a node produced while building the graph (for example the
# `mx.FullyConnected` node created for `Affine`). This method returns `node`
# unchanged and does not use `ctx`.
register(ctx::Context, node) = node
|
||||
|
||||
# Record `p` as a parameter of the graph being built and return a symbolic
# variable that refers to it: a fresh name is generated with `gensym()`, a value
# is stored under that name in `ctx[:params]`, and `mx.Variable(id)` is returned.
#
# NOTE(review): this span comes from a rendered diff and appears to show both
# versions of the signature and of the `ctx[:params][id]` assignment: the
# three-argument form that stores a rebatched/reshaped copy of `p.x`, and the
# two-argument form that stores `p` itself. Confirm which one is current.
function var(ctx::Context, p::Flux.Param, size = nothing)
|
||||
function var(ctx::Context, p)
|
||||
id = gensym()
|
||||
ctx[:params][id] = size == nothing ? rebatch_last(p.x) : reshape(p.x, size...)
|
||||
ctx[:params][id] = p
|
||||
return mx.Variable(id)
|
||||
end
|
||||
|
||||
|
@ -1,11 +1,39 @@
|
||||
using Flux: batchone, rebatch
|
||||
|
||||
# MXNet batches on last dimension
|
||||
# Move the first dimension of `xs` to the end, keeping the order of the others.
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
|
||||
# Inverse permutation: move the last dimension of `xs` to the front.
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
|
||||
|
||||
# Array form of a parameter, as copied into the executor by `loadparams!`.
# Fallback method: apply `rebatch_last` to `p` directly.
paramvalue(p) = rebatch_last(p)
|
||||
# For a `Flux.Param`, unwrap the stored array `p.x` and convert that.
paramvalue(p::Flux.Param) = paramvalue(p.x)
|
||||
|
||||
# Basically a kludge to make Affine work
|
||||
# Hopefully will go away with more inference
|
||||
# Wraps a `Flux.Param` together with two flags that `paramvalue(::AlterParam)`
# consults when producing the array for this parameter:
#   param   - the wrapped parameter
#   strip   - when true, `squeeze(val, 1)` is applied to the value
#   rebatch - when true, the value is `paramvalue(param)` (dimensions permuted
#             by `rebatch_last`); when false, `param.x` is used unchanged
type AlterParam
|
||||
param::Flux.Param
|
||||
strip::Bool
|
||||
rebatch::Bool
|
||||
end
|
||||
|
||||
# Array form of an `AlterParam`, computed in two steps:
#   1. take `paramvalue(p.param)` when `p.rebatch` is set, otherwise the raw
#      array `p.param.x`;
#   2. when `p.strip` is set, remove dimension 1 with `squeeze(val, 1)`.
# The value of the final expression is returned.
function paramvalue(p::AlterParam)
|
||||
val = p.rebatch ? paramvalue(p.param) : p.param.x
|
||||
p.strip ? squeeze(val, 1) : val
|
||||
end
|
||||
|
||||
# Result of converting a model to an MXNet graph.
#   node   - the symbolic node handed to `mx.bind` / `mx.FeedForward`
#   params - maps variable names to parameter objects; each value is something
#            `paramvalue` accepts (see `mxparams` and `loadparams!`)
#   stacks - passed to `@mxerr` in `mxnet`; its contents are not visible in
#            this section
type Graph
|
||||
node::mx.SymbolicNode
|
||||
params::Dict{Symbol,Any}
|
||||
stacks::Dict{Any,Any}
|
||||
end
|
||||
|
||||
# Build a `Dict{Symbol,mx.NDArray}` holding, for every entry of `g.params`, a
# zero-filled `mx.NDArray` with the same size as `paramvalue(param)`.
#
# Only the shapes are taken from the parameters; the values are zeros. The
# actual parameter values are copied in separately by `loadparams!`.
# NOTE(review): `mx.FeedForward` assigns this result to `ff.arg_params`, so the
# network would start from all-zero parameters there. Confirm this is intended.
function mxparams(g::Graph)
|
||||
params = Dict{Symbol,mx.NDArray}()
|
||||
for (name, param) in g.params
|
||||
params[name] = mx.zeros(size(paramvalue(param)))
|
||||
end
|
||||
return params
|
||||
end
|
||||
|
||||
type Model <: Flux.Model
|
||||
model::Any
|
||||
graph::Graph
|
||||
@ -13,37 +41,17 @@ type Model <: Flux.Model
|
||||
exec::mx.Executor
|
||||
end
|
||||
|
||||
# Convert an array to an `mx.NDArray`: allocate `mx.zeros` with the size of
# `xs` and copy the contents of `xs` into it.
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||
|
||||
# Overwrite `xs` in place with zeros by copying from a freshly allocated
# `mx.zeros` array of the same size.
ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
|
||||
|
||||
# Convert a collection of `name => array` entries to `name => mx.NDArray`
# entries by applying `tond` to every value.
# An empty input short-circuits to an empty `Dict{Symbol,mx.NDArray}`.
function mxargs(args)
|
||||
isempty(args) && return Dict{Symbol,mx.NDArray}()
|
||||
map(args) do kv
|
||||
arg, value = kv
|
||||
arg => tond(value)
|
||||
end
|
||||
end
|
||||
|
||||
# For every `name => value` entry of `mxargs`, produce `name => mx.zeros(...)`
# with the same size as `value`.
# An empty input short-circuits to an empty `Dict{Symbol,mx.NDArray}`.
# Note that the parameter name `mxargs` shadows the function of the same name
# inside this body.
function mxgrads(mxargs)
|
||||
isempty(mxargs) && return Dict{Symbol,mx.NDArray}()
|
||||
map(mxargs) do kv
|
||||
arg, value = kv
|
||||
arg => mx.zeros(size(value))
|
||||
end
|
||||
end
|
||||
|
||||
# Copy the model's parameter values into the executor's argument arrays.
# Every `(name, arr)` in `model.exec.arg_dict` whose name is a key of
# `model.graph.params` is overwritten in place; other arguments are left
# untouched. Returns `model`.
#
# NOTE(review): this span comes from a rendered diff and appears to show both
# versions of the copy: one copying the stored value directly, and one
# converting it with `paramvalue` first. Confirm which one is current.
function loadparams!(model::Model)
|
||||
for (name, arr) in model.exec.arg_dict
|
||||
haskey(model.graph.params, name) && copy!(arr, model.graph.params[name])
|
||||
haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name]))
|
||||
end
|
||||
return model
|
||||
end
|
||||
|
||||
function mxnet(model::Flux.Model, input)
|
||||
graph = tograph(model, mx.Variable(:input))
|
||||
args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input)))
|
||||
grads = mxgrads(args)
|
||||
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||
grads = mxparams(graph)
|
||||
model = @mxerr graph.stacks Model(model, graph, grads,
|
||||
mx.bind(graph.node, args = args,
|
||||
args_grad = grads,
|
||||
@ -52,10 +60,6 @@ function mxnet(model::Flux.Model, input)
|
||||
return model
|
||||
end
|
||||
|
||||
# MXNet batches on last dimension
|
||||
# Move the first dimension of `xs` to the end, keeping the order of the others.
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
|
||||
# Inverse permutation: move the last dimension of `xs` to the front.
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
|
||||
|
||||
function runmodel(model::Model, input)
|
||||
copy!(model.exec.arg_dict[:input], input)
|
||||
mx.forward(model.exec, is_train = true)
|
||||
@ -101,6 +105,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont
|
||||
model = rewrite_softmax(model, label)
|
||||
graph = tograph(model, mx.Variable(input), feedforward=true)
|
||||
ff = mx.FeedForward(graph.node, context = context)
|
||||
isempty(graph.params) || (ff.arg_params = mxargs(graph.params))
|
||||
isempty(graph.params) || (ff.arg_params = mxparams(graph))
|
||||
return ff
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user