support constant arrays in MXNet
This commit is contained in:
parent
a2db4b5319
commit
5be9ce45d8
@ -73,6 +73,8 @@ end
|
||||
|
||||
# Lower a `Constant` whose payload is a `Flux.Param` wrapping an `AArray`:
# the wrapped value is handed to `var(ctx, p.value)`.
# NOTE(review): `var` is defined outside this view — presumably it registers the
# value as a named variable in `ctx`; confirm against its definition.
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value)
|
||||
|
||||
# Lower a `Constant` holding a bare `AArray` (no `Flux.Param` wrapper) the same
# way as the `Flux.Param` case: the array is passed to `var(ctx, p.value)`.
# NOTE(review): the precise effect of `var` is not visible here — confirm.
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
|
||||
|
||||
# Fallback for any other `Constant`: return the wrapped value unchanged; `ctx`
# is not used by this method.
graph(ctx::Context, p::Constant) = p.value
|
||||
|
||||
function graph(ctx::Context, model, args...)
|
||||
|
@ -22,9 +22,9 @@ struct Graph
|
||||
stacks::Dict{Any,Any}
|
||||
end
|
||||
|
||||
function mxparams(g::Graph)
|
||||
function mxparams(ps)
|
||||
params = Dict{Symbol,MXArray}()
|
||||
for (name, param) in g.params
|
||||
for (name, param) in ps
|
||||
params[name] = MXArray(size(param))
|
||||
end
|
||||
return params
|
||||
@ -52,8 +52,9 @@ dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
|
||||
|
||||
function executor(graph::Graph, input...)
|
||||
shapecheckt(graph.input, input)
|
||||
args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||
grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||
args = merge(mxparams(graph.params), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||
grads = filter((a, b) -> b isa Flux.Param, graph.params)
|
||||
grads = merge(mxparams(grads), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
||||
exec = mx.bind(mxgroup(graph.output),
|
||||
args = ndparams(args),
|
||||
args_grad = ndparams(grads),
|
||||
@ -77,6 +78,7 @@ end
|
||||
|
||||
function Flux.update!(exec::Exec, η)
|
||||
for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
|
||||
grad == nothing && continue
|
||||
mx.@nd_as_jl rw = (arg, grad) begin
|
||||
arg .-= grad .* η
|
||||
grad[:] = 0
|
||||
@ -145,6 +147,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont
|
||||
model = rewrite_softmax(model, label)
|
||||
graph = tograph(model, input, feedforward=true)
|
||||
ff = mx.FeedForward(graph.output, context = context)
|
||||
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
|
||||
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params)))
|
||||
return ff
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user