mxnet multi output
This commit is contained in:
parent
d9910070b4
commit
5d919175fc
|
@ -1,8 +1,9 @@
|
|||
using Flux: batchone, rebatch
|
||||
using Flux: batchone, unbatchone, rebatch
|
||||
|
||||
# MNet batches on last dimension
|
||||
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
|
||||
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
|
||||
rebatch_first(xs::Tuple) = rebatch_first.(xs)
|
||||
|
||||
paramvalue(p) = rebatch_last(p)
|
||||
paramvalue(p::Flux.Param) = paramvalue(p.x)
|
||||
|
@ -21,7 +22,7 @@ function paramvalue(p::AlterParam)
|
|||
end
|
||||
|
||||
type Graph
|
||||
node::mx.SymbolicNode
|
||||
output
|
||||
params::Dict{Symbol,Any}
|
||||
stacks::Dict{Any,Any}
|
||||
end
|
||||
|
@ -56,12 +57,17 @@ end
|
|||
loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
|
||||
storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
|
||||
|
||||
mxgroup(x) = x
|
||||
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
|
||||
mxungroup(x, outs) = copy(shift!(outs))
|
||||
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
||||
|
||||
function mxnet(model::Flux.Model, input)
|
||||
graph = tograph(model, mx.Variable(:input))
|
||||
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||
grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||
model = @mxerr graph.stacks Model(model, graph, grads,
|
||||
mx.bind(graph.node, args = args,
|
||||
mx.bind(mxgroup(graph.output), args = args,
|
||||
args_grad = grads,
|
||||
grad_req = mx.GRAD_ADD))
|
||||
loadparams!(model)
|
||||
|
@ -71,12 +77,12 @@ end
|
|||
function runmodel(model::Model, input)
|
||||
copy!(model.exec.arg_dict[:input], input)
|
||||
mx.forward(model.exec, is_train = true)
|
||||
copy(model.exec.outputs[1])
|
||||
mxungroup(model.graph.output, copy(model.exec.outputs))
|
||||
end
|
||||
|
||||
(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
|
||||
|
||||
(m::Model)(x) = first(m(batchone(x)))
|
||||
(m::Model)(x) = unbatchone(m(batchone(x)))
|
||||
|
||||
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||
|
||||
|
@ -119,7 +125,7 @@ end
|
|||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||
model = rewrite_softmax(model, label)
|
||||
graph = tograph(model, mx.Variable(input), feedforward=true)
|
||||
ff = mx.FeedForward(graph.node, context = context)
|
||||
ff = mx.FeedForward(graph.output, context = context)
|
||||
isempty(graph.params) || (ff.arg_params = mxparams(graph))
|
||||
return ff
|
||||
end
|
||||
|
|
|
@ -9,6 +9,10 @@ d = Affine(20, 10)
|
|||
dm = mxnet(d, (20, 1))
|
||||
@test d(xs) ≈ dm(xs)
|
||||
|
||||
m = Multi(20, 15)
|
||||
mm = mxnet(m, (20, 1))
|
||||
@test all(isapprox.(mm(xs), m(xs)))
|
||||
|
||||
@testset "Backward Pass" begin
|
||||
d′ = deepcopy(d)
|
||||
@test dm(xs) ≈ d(xs)
|
||||
|
|
Loading…
Reference in New Issue