# Patch summary (commentary before the first "diff --git" line is ignored by patch tools):
#
#   * Adds loadparams!(g::Graph, args): for every (id, param) in g.params whose id is a
#     key of `args`, copies paramvalue(param) into args[id].
#   * Adds storeparams!(g::Graph, args): the reverse direction; for every matching id,
#     copies rebatch_first(copy(args[id])) into param.x.
#   * Replaces the hand-written loadparams!(model::Model) loop with one-line methods
#     that forward model.graph and model.exec.arg_dict to the Graph-level functions.
#   * Flux.update!(model::Model, η) now calls storeparams!(model) before returning.
#
# NOTE(review): the removed loadparams!(model::Model) ended with `return model`; the new
#   forwarding method returns whatever the Graph method returns, and that body is a bare
#   `for` loop with no explicit return (i.e. `nothing`). Confirm no caller relies on the
#   old return value.
#
# NOTE(review): this patch was recovered from a copy in which all line breaks and
#   indentation had been collapsed. Line structure was rebuilt from the @@ counts.
#   Leading indentation of the Julia lines is an assumption (two-space style) - verify
#   against the base file or apply with `git apply --ignore-whitespace`. The last hunk
#   header originally read `-89,6 +97,7`; only 5/6 lines were present (a trailing blank
#   context line was presumably lost), so the counts were adjusted to match.
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index b25636d0..8b9e4b6d 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -34,6 +34,18 @@ function mxparams(g::Graph)
   return params
 end
 
+function loadparams!(g::Graph, args)
+  for (id, param) in g.params
+    haskey(args, id) && copy!(args[id], paramvalue(param))
+  end
+end
+
+function storeparams!(g::Graph, args)
+  for (id, param) in g.params
+    haskey(args, id) && copy!(param.x, rebatch_first(copy(args[id])))
+  end
+end
+
 type Model <: Flux.Model
   model::Any
   graph::Graph
@@ -41,12 +53,8 @@ type Model <: Flux.Model
   exec::mx.Executor
 end
 
-function loadparams!(model::Model)
-  for (name, arr) in model.exec.arg_dict
-    haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name]))
-  end
-  return model
-end
+loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
+storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
 
 function mxnet(model::Flux.Model, input)
   graph = tograph(model, mx.Variable(:input))
@@ -89,5 +97,6 @@ function Flux.update!(model::Model, η)
       grad[:] = 0
     end
   end
+  storeparams!(model)
   return model
 end