parameter storage

This commit is contained in:
Mike J Innes 2017-02-23 21:42:34 +00:00
parent a4812579e9
commit 06fd5adddc

View File

@ -34,6 +34,18 @@ function mxparams(g::Graph)
return params
end
function loadparams!(g::Graph, args)
for (id, param) in g.params
haskey(args, id) && copy!(args[id], paramvalue(param))
end
end
function storeparams!(g::Graph, args)
for (id, param) in g.params
haskey(args, id) && copy!(param.x, rebatch_first(copy(args[id])))
end
end
type Model <: Flux.Model
model::Any
graph::Graph
@ -41,12 +53,8 @@ type Model <: Flux.Model
exec::mx.Executor
end
function loadparams!(model::Model)
for (name, arr) in model.exec.arg_dict
haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name]))
end
return model
end
loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
function mxnet(model::Flux.Model, input)
graph = tograph(model, mx.Variable(:input))
@ -89,6 +97,7 @@ function Flux.update!(model::Model, η)
grad[:] = 0
end
end
storeparams!(model)
return model
end