mxnet multi output
This commit is contained in:
parent
d9910070b4
commit
5d919175fc
@ -1,8 +1,9 @@
|
|||||||
using Flux: batchone, rebatch
|
using Flux: batchone, unbatchone, rebatch
|
||||||
|
|
||||||
# MNet batches on last dimension
|
# MNet batches on last dimension
|
||||||
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
|
rebatch_last(xs) = permutedims(xs, (2:ndims(xs)..., 1))
|
||||||
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
|
rebatch_first(xs) = permutedims(xs, (ndims(xs), 1:ndims(xs)-1...))
|
||||||
|
rebatch_first(xs::Tuple) = rebatch_first.(xs)
|
||||||
|
|
||||||
paramvalue(p) = rebatch_last(p)
|
paramvalue(p) = rebatch_last(p)
|
||||||
paramvalue(p::Flux.Param) = paramvalue(p.x)
|
paramvalue(p::Flux.Param) = paramvalue(p.x)
|
||||||
@ -21,7 +22,7 @@ function paramvalue(p::AlterParam)
|
|||||||
end
|
end
|
||||||
|
|
||||||
type Graph
|
type Graph
|
||||||
node::mx.SymbolicNode
|
output
|
||||||
params::Dict{Symbol,Any}
|
params::Dict{Symbol,Any}
|
||||||
stacks::Dict{Any,Any}
|
stacks::Dict{Any,Any}
|
||||||
end
|
end
|
||||||
@ -56,12 +57,17 @@ end
|
|||||||
loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
|
loadparams!(model::Model) = loadparams!(model.graph, model.exec.arg_dict)
|
||||||
storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
|
storeparams!(model::Model) = storeparams!(model.graph, model.exec.arg_dict)
|
||||||
|
|
||||||
|
mxgroup(x) = x
|
||||||
|
mxgroup(x::Tuple) = mx.Group(mxgroup.(x)...)
|
||||||
|
mxungroup(x, outs) = copy(shift!(outs))
|
||||||
|
mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
||||||
|
|
||||||
function mxnet(model::Flux.Model, input)
|
function mxnet(model::Flux.Model, input)
|
||||||
graph = tograph(model, mx.Variable(:input))
|
graph = tograph(model, mx.Variable(:input))
|
||||||
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
args = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||||
grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
grads = merge(mxparams(graph), Dict(:input => mx.zeros(input)))
|
||||||
model = @mxerr graph.stacks Model(model, graph, grads,
|
model = @mxerr graph.stacks Model(model, graph, grads,
|
||||||
mx.bind(graph.node, args = args,
|
mx.bind(mxgroup(graph.output), args = args,
|
||||||
args_grad = grads,
|
args_grad = grads,
|
||||||
grad_req = mx.GRAD_ADD))
|
grad_req = mx.GRAD_ADD))
|
||||||
loadparams!(model)
|
loadparams!(model)
|
||||||
@ -71,12 +77,12 @@ end
|
|||||||
function runmodel(model::Model, input)
|
function runmodel(model::Model, input)
|
||||||
copy!(model.exec.arg_dict[:input], input)
|
copy!(model.exec.arg_dict[:input], input)
|
||||||
mx.forward(model.exec, is_train = true)
|
mx.forward(model.exec, is_train = true)
|
||||||
copy(model.exec.outputs[1])
|
mxungroup(model.graph.output, copy(model.exec.outputs))
|
||||||
end
|
end
|
||||||
|
|
||||||
(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
|
(m::Model)(x::Batch) = rebatch(rebatch_first(runmodel(m, rebatch_last(rawbatch(x)))))
|
||||||
|
|
||||||
(m::Model)(x) = first(m(batchone(x)))
|
(m::Model)(x) = unbatchone(m(batchone(x)))
|
||||||
|
|
||||||
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||||
|
|
||||||
@ -119,7 +125,7 @@ end
|
|||||||
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
graph = tograph(model, mx.Variable(input), feedforward=true)
|
graph = tograph(model, mx.Variable(input), feedforward=true)
|
||||||
ff = mx.FeedForward(graph.node, context = context)
|
ff = mx.FeedForward(graph.output, context = context)
|
||||||
isempty(graph.params) || (ff.arg_params = mxparams(graph))
|
isempty(graph.params) || (ff.arg_params = mxparams(graph))
|
||||||
return ff
|
return ff
|
||||||
end
|
end
|
||||||
|
@ -9,6 +9,10 @@ d = Affine(20, 10)
|
|||||||
dm = mxnet(d, (20, 1))
|
dm = mxnet(d, (20, 1))
|
||||||
@test d(xs) ≈ dm(xs)
|
@test d(xs) ≈ dm(xs)
|
||||||
|
|
||||||
|
m = Multi(20, 15)
|
||||||
|
mm = mxnet(m, (20, 1))
|
||||||
|
@test all(isapprox.(mm(xs), m(xs)))
|
||||||
|
|
||||||
@testset "Backward Pass" begin
|
@testset "Backward Pass" begin
|
||||||
d′ = deepcopy(d)
|
d′ = deepcopy(d)
|
||||||
@test dm(xs) ≈ d(xs)
|
@test dm(xs) ≈ d(xs)
|
||||||
|
Loading…
Reference in New Issue
Block a user