parameter storage
This commit is contained in:
parent
a4812579e9
commit
06fd5adddc
@ -34,6 +34,18 @@ function mxparams(g::Graph)
|
||||
return params
|
||||
end
|
||||
|
||||
function loadparams!(g::Graph, args)
|
||||
for (id, param) in g.params
|
||||
haskey(args, id) && copy!(args[id], paramvalue(param))
|
||||
end
|
||||
end
|
||||
|
||||
function storeparams!(g::Graph, args)
|
||||
for (id, param) in g.params
|
||||
haskey(args, id) && copy!(param.x, rebatch_first(copy(args[id])))
|
||||
end
|
||||
end
|
||||
|
||||
type Model <: Flux.Model
|
||||
model::Any
|
||||
graph::Graph
|
||||
@ -41,12 +53,8 @@ type Model <: Flux.Model
|
||||
exec::mx.Executor
|
||||
end
|
||||
|
||||
function loadparams!(model::Model)
|
||||
for (name, arr) in model.exec.arg_dict
|
||||
haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name]))
|
||||
end
|
||||
return model
|
||||
end
|
||||
loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
|
||||
storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
|
||||
|
||||
function mxnet(model::Flux.Model, input)
|
||||
graph = tograph(model, mx.Variable(:input))
|
||||
@ -89,6 +97,7 @@ function Flux.update!(model::Model, η)
|
||||
grad[:] = 0
|
||||
end
|
||||
end
|
||||
storeparams!(model)
|
||||
return model
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user