mxnet multi output

This commit is contained in:
Mike J Innes 2017-03-06 17:20:15 +00:00
parent d9910070b4
commit 5d919175fc
2 changed files with 16 additions and 6 deletions

View File

@ -1,8 +1,9 @@
using Flux: batchone, rebatch
using Flux: batchone, unbatchone, rebatch
# MNet batches on last dimension
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
rebatch_first(xs::Tuple) = rebatch_first.(xs)
paramvalue(p) = rebatch_last(p)
paramvalue(p::Flux.Param) = paramvalue(p.x)
@ -21,7 +22,7 @@ function paramvalue(p::AlterParam)
end
type Graph
node::mx.SymbolicNode
output
params::Dict{Symbol,Any}
stacks::Dict{Any,Any}
end
@ -56,12 +57,17 @@ end
loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
mxgroup(x) = x
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
mxungroup(x, outs) = copy(shift!(outs))
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
function mxnet(model::Flux.Model, input)
graph = tograph(model, mx.Variable(:input))
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
model = @mxerr graph.stacks Model(model, graph, grads,
mx.bind(graph.node, args = args,
mx.bind(mxgroup(graph.output), args = args,
args_grad = grads,
grad_req = mx.GRAD_ADD))
loadparams!(model)
@ -71,12 +77,12 @@ end
function runmodel(model::Model, input)
copy!(model.exec.arg_dict[:input], input)
mx.forward(model.exec, is_train = true)
copy(model.exec.outputs[1])
mxungroup(model.graph.output, copy(model.exec.outputs))
end
(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
(m::Model)(x) = first(m(batchone(x)))
(m::Model)(x) = unbatchone(m(batchone(x)))
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
@ -119,7 +125,7 @@ end
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label)
graph = tograph(model, mx.Variable(input), feedforward=true)
ff = mx.FeedForward(graph.node, context = context)
ff = mx.FeedForward(graph.output, context = context)
isempty(graph.params) || (ff.arg_params = mxparams(graph))
return ff
end

View File

@ -9,6 +9,10 @@ d = Affine(20, 10)
dm = mxnet(d, (20, 1))
@test d(xs) dm(xs)
m = Multi(20, 15)
mm = mxnet(m, (20, 1))
@test all(isapprox.(mm(xs), m(xs)))
@testset "Backward Pass" begin
d = deepcopy(d)
@test dm(xs) d(xs)