tweak param loading
parent 2f2ff0b03b
commit 2d77220d60
@@ -38,8 +38,8 @@ graph(ctx::Context, d::Affine, x) =
   register(ctx,
     mx.FullyConnected(x,
                       num_hidden = size(d.W.x, 2),
-                      weight = var(ctx, d.W, size(d.W)),
-                      bias = var(ctx, d.b, size(d.b, 2))))
+                      weight = var(ctx, AlterParam(d.W, false, false)),
+                      bias = var(ctx, AlterParam(d.b, true, false))))
 
 # TODO: use actual params}
 graph(ctx::Context, c::Conv2D, x) =
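The two flags on AlterParam are defined in the new-file hunk further down: `strip` squeezes a leading singleton dimension, and `rebatch` permutes the value batch-last before it reaches MXNet. For the Affine layer the weight needs neither, while the 1×n bias is stripped to a plain vector for mx.FullyConnected. A minimal sketch of just the strip step, assuming a 1×5 bias as Affine stores it:

  b = rand(1, 5)             # bias as Affine stores it: a 1×5 matrix
  val = squeeze(b, 1)        # the strip step from paramvalue(p::AlterParam)
  @assert size(val) == (5,)  # a length-5 vector, as the bias argument expects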
@@ -61,9 +61,9 @@ end
 
 register(ctx::Context, node) = node
 
-function var(ctx::Context, p::Flux.Param, size = nothing)
+function var(ctx::Context, p)
   id = gensym()
-  ctx[:params][id] = size == nothing ? rebatch_last(p.x) : reshape(p.x, size...)
+  ctx[:params][id] = p
   return mx.Variable(id)
 end
 
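The old `var` converted the parameter to its MXNet layout immediately; the new one just records the parameter object under a fresh symbol and leaves the conversion to `paramvalue` (added below) at load time. A small sketch of that deferred pattern, using a plain Dict in place of the context (the names here are illustrative, not part of the API):

  params = Dict{Symbol,Any}()

  function var_sketch(p)   # hypothetical stand-in for var(ctx, p)
    id = gensym()
    params[id] = p         # store the parameter itself, not its value
    return id              # the real code returns mx.Variable(id)
  end

  id = var_sketch(rand(3, 4))
  val = paramvalue(params[id])  # layout conversion happens only here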
@@ -1,11 +1,39 @@
 using Flux: batchone, rebatch
 
+# MNet batches on last dimension
+rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
+rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
+
+paramvalue(p) = rebatch_last(p)
+paramvalue(p::Flux.Param) = paramvalue(p.x)
+
+# Basically a kludge to make Affine work
+# Hopefully will go away with more inference
+type AlterParam
+  param::Flux.Param
+  strip::Bool
+  rebatch::Bool
+end
+
+function paramvalue(p::AlterParam)
+  val = p.rebatch ? paramvalue(p.param) : p.param.x
+  p.strip ? squeeze(val, 1) : val
+end
+
 type Graph
   node::mx.SymbolicNode
   params::Dict{Symbol,Any}
   stacks::Dict{Any,Any}
 end
 
+function mxparams(g::Graph)
+  params = Dict{Symbol,mx.NDArray}()
+  for (name, param) in g.params
+    params[name] = mx.zeros(size(paramvalue(param)))
+  end
+  return params
+end
+
 type Model <: Flux.Model
   model::Any
   graph::Graph
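`rebatch_last` and `rebatch_first` are inverse permutations between Flux's batch-first layout and MXNet's batch-last one, and `mxparams` uses the converted shape to allocate a zero NDArray per registered parameter. A quick check of the permutation on a batch of two length-3 samples:

  xs = rand(2, 3)                                # batch-first: 2 samples of length 3
  @assert size(rebatch_last(xs)) == (3, 2)       # batch moved to the last dimension
  @assert rebatch_first(rebatch_last(xs)) == xs  # the two permutations are inverses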
@@ -13,37 +41,17 @@ type Model <: Flux.Model
   exec::mx.Executor
 end
 
-tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
-
-ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
-
-function mxargs(args)
-  isempty(args) && return Dict{Symbol,mx.NDArray}()
-  map(args) do kv
-    arg, value = kv
-    arg => tond(value)
-  end
-end
-
-function mxgrads(mxargs)
-  isempty(mxargs) && return Dict{Symbol,mx.NDArray}()
-  map(mxargs) do kv
-    arg, value = kv
-    arg => mx.zeros(size(value))
-  end
-end
-
 function loadparams!(model::Model)
   for (name, arr) in model.exec.arg_dict
-    haskey(model.graph.params, name) && copy!(arr, model.graph.params[name])
+    haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name]))
   end
   return model
 end
 
 function mxnet(model::Flux.Model, input)
   graph = tograph(model, mx.Variable(:input))
-  args = merge(mxargs(graph.params), Dict(:input => mx.zeros(input)))
-  grads = mxgrads(args)
+  args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
+  grads = mxparams(graph)
   model = @mxerr graph.stacks Model(model, graph, grads,
                                     mx.bind(graph.node, args = args,
                                             args_grad = grads,
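With `var` now storing parameter objects, `loadparams!` is the point where values actually cross into MXNet: each stored parameter is run through `paramvalue` to get its batch-last array before being copied into the executor's argument buffer. The removed `mxargs`/`mxgrads` pair collapses into `mxparams`, called once for arguments and once more for gradients so the two sets of NDArrays don't alias. A hedged sketch of that allocation, reusing the Graph type from above (names and shapes illustrative):

  node = mx.Variable(:x)   # any symbolic node will do for the sketch
  g = Graph(node, Dict{Symbol,Any}(:w => rand(3, 4)), Dict())
  args  = mxparams(g)      # zero NDArrays, one per parameter, in MXNet layout
  grads = mxparams(g)      # an independent set of zero gradient buffers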
@@ -52,10 +60,6 @@ function mxnet(model::Flux.Model, input)
   return model
 end
 
-# MNet batches on last dimension
-rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
-rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
-
 function runmodel(model::Model, input)
   copy!(model.exec.arg_dict[:input], input)
   mx.forward(model.exec, is_train = true)
@@ -101,6 +105,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont
   model = rewrite_softmax(model, label)
   graph = tograph(model, mx.Variable(input), feedforward=true)
   ff = mx.FeedForward(graph.node, context = context)
-  isempty(graph.params) || (ff.arg_params = mxargs(graph.params))
+  isempty(graph.params) || (ff.arg_params = mxparams(graph))
   return ff
 end
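Because `graph.params` now maps symbols to parameter objects rather than raw arrays, the FeedForward path also has to go through `mxparams` to materialize NDArrays. A hypothetical end-to-end use, assuming Flux's `Affine` layer and illustrative sizes:

  m = Flux.Affine(10, 5)   # assumption: a single dense layer from Flux
  ff = mx.FeedForward(m)   # the converted model, via the method above
  # ff.arg_params now holds zero NDArrays produced by mxparams(graph)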