support constant arrays in MXNet

This commit is contained in:
Mike J Innes 2017-05-04 15:09:18 +01:00
parent a2db4b5319
commit 5be9ce45d8
2 changed files with 9 additions and 5 deletions

View File

@ -73,6 +73,8 @@ end
graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value)
graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value)
graph(ctx::Context, p::Constant) = p.value
function graph(ctx::Context, model, args...)

View File

@ -22,9 +22,9 @@ struct Graph
stacks::Dict{Any,Any}
end
function mxparams(g::Graph)
function mxparams(ps)
params = Dict{Symbol,MXArray}()
for (name, param) in g.params
for (name, param) in ps
params[name] = MXArray(size(param))
end
return params
@ -52,8 +52,9 @@ dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
function executor(graph::Graph, input...)
shapecheckt(graph.input, input)
args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
args = merge(mxparams(graph.params), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
grads = filter((a, b) -> b isa Flux.Param, graph.params)
grads = merge(mxparams(grads), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
exec = mx.bind(mxgroup(graph.output),
args = ndparams(args),
args_grad = ndparams(grads),
@ -77,6 +78,7 @@ end
function Flux.update!(exec::Exec, η)
for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
grad == nothing && continue
mx.@nd_as_jl rw = (arg, grad) begin
arg .-= grad .* η
grad[:] = 0
@ -145,6 +147,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont
model = rewrite_softmax(model, label)
graph = tograph(model, input, feedforward=true)
ff = mx.FeedForward(graph.output, context = context)
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph)))
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params)))
return ff
end