From 06fd5adddc8b952072caa7c515dd54e48539ce2b Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 23 Feb 2017 21:42:34 +0000 Subject: [PATCH] parameter storage --- src/backend/mxnet/model.jl | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index b25636d0..8b9e4b6d 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -34,6 +34,18 @@ function mxparams(g::Graph) return params end +function loadparams!(g::Graph, args) + for (id, param) in g.params + haskey(args, id) && copy!(args[id], paramvalue(param)) + end +end + +function storeparams!(g::Graph, args) + for (id, param) in g.params + haskey(args, id) && copy!(param.x, rebatch_first(copy(args[id]))) + end +end + type Model <: Flux.Model model::Any graph::Graph @@ -41,12 +53,8 @@ type Model <: Flux.Model exec::mx.Executor end -function loadparams!(model::Model) - for (name, arr) in model.exec.arg_dict - haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name])) - end - return model -end +loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict) +storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict) function mxnet(model::Flux.Model, input) graph = tograph(model, mx.Variable(:input)) @@ -89,6 +97,7 @@ function Flux.update!(model::Model, η) grad[:] = 0 end end + storeparams!(model) return model end