parameter storage
This commit is contained in:
parent
a4812579e9
commit
06fd5adddc
@ -34,6 +34,18 @@ function mxparams(g::Graph)
|
|||||||
return params
|
return params
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function loadparams!(g::Graph, args)
|
||||||
|
for (id, param) in g.params
|
||||||
|
haskey(args, id) && copy!(args[id], paramvalue(param))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function storeparams!(g::Graph, args)
|
||||||
|
for (id, param) in g.params
|
||||||
|
haskey(args, id) && copy!(param.x, rebatch_first(copy(args[id])))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
type Model <: Flux.Model
|
type Model <: Flux.Model
|
||||||
model::Any
|
model::Any
|
||||||
graph::Graph
|
graph::Graph
|
||||||
@ -41,12 +53,8 @@ type Model <: Flux.Model
|
|||||||
exec::mx.Executor
|
exec::mx.Executor
|
||||||
end
|
end
|
||||||
|
|
||||||
function loadparams!(model::Model)
|
loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
|
||||||
for (name, arr) in model.exec.arg_dict
|
storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
|
||||||
haskey(model.graph.params, name) && copy!(arr, paramvalue(model.graph.params[name]))
|
|
||||||
end
|
|
||||||
return model
|
|
||||||
end
|
|
||||||
|
|
||||||
function mxnet(model::Flux.Model, input)
|
function mxnet(model::Flux.Model, input)
|
||||||
graph = tograph(model, mx.Variable(:input))
|
graph = tograph(model, mx.Variable(:input))
|
||||||
@ -89,6 +97,7 @@ function Flux.update!(model::Model, η)
|
|||||||
grad[:] = 0
|
grad[:] = 0
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
storeparams!(model)
|
||||||
return model
|
return model
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user