From 5be9ce45d8cba21cdcb96fa1eed0371508d9f44c Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 4 May 2017 15:09:18 +0100 Subject: [PATCH] support constant arrays in MXNet --- src/backend/mxnet/graph.jl | 2 ++ src/backend/mxnet/model.jl | 12 +++++++----- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 1d074a19..27cf9bc5 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -73,6 +73,8 @@ end graph{T<:AArray}(ctx::Context, p::Constant{Flux.Param{T}}) = var(ctx, p.value) +graph{T<:AArray}(ctx::Context, p::Constant{T}) = var(ctx, p.value) + graph(ctx::Context, p::Constant) = p.value function graph(ctx::Context, model, args...) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index cdeb1b2d..507c01c2 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -22,9 +22,9 @@ struct Graph stacks::Dict{Any,Any} end -function mxparams(g::Graph) +function mxparams(ps) params = Dict{Symbol,MXArray}() - for (name, param) in g.params + for (name, param) in ps params[name] = MXArray(size(param)) end return params @@ -52,8 +52,9 @@ dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys))) function executor(graph::Graph, input...) shapecheckt(graph.input, input) - args = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input))) - grads = merge(mxparams(graph), dictt(graph.input, mapt(d->MXArray(size(d)), input))) + args = merge(mxparams(graph.params), dictt(graph.input, mapt(d->MXArray(size(d)), input))) + grads = filter((a, b) -> b isa Flux.Param, graph.params) + grads = merge(mxparams(grads), dictt(graph.input, mapt(d->MXArray(size(d)), input))) exec = mx.bind(mxgroup(graph.output), args = ndparams(args), args_grad = ndparams(grads), @@ -77,6 +78,7 @@ end function Flux.update!(exec::Exec, η) for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays) + grad == nothing && continue mx.@nd_as_jl rw = (arg, grad) begin arg .-= grad .* η grad[:] = 0 @@ -145,6 +147,6 @@ function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, cont model = rewrite_softmax(model, label) graph = tograph(model, input, feedforward=true) ff = mx.FeedForward(graph.output, context = context) - isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph))) + isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params))) return ff end