Flux.jl/src/backend/mxnet/model.jl

117 lines
3.2 KiB
Julia
Raw Normal View History

2017-01-30 18:05:15 +00:00
using Flux: batchone, rebatch
2017-01-28 17:02:49 +00:00
2017-02-23 18:48:46 +00:00
# MXNet batches on the last dimension
# Move the first dimension of `xs` to the end, keeping the remaining
# dimensions in their original order.
function rebatch_last(xs)
  n = ndims(xs)
  order = tuple((2:n)..., 1)
  return permutedims(xs, order)
end
# Move the last dimension of `xs` to the front, keeping the remaining
# dimensions in their original order.
function rebatch_first(xs)
  n = ndims(xs)
  order = tuple(n, (1:n-1)...)
  return permutedims(xs, order)
end
# Convert a raw parameter value by moving its first dimension to the end
# (see `rebatch_last`).
function paramvalue(p)
  return rebatch_last(p)
end

# A `Flux.Param` is unwrapped to its `x` field, which is then converted.
function paramvalue(p::Flux.Param)
  return paramvalue(p.x)
end
# Basically a kludge to make Affine work.
# Hopefully this will go away with more inference.
#
# Wraps a `Flux.Param` together with two flags read by `paramvalue`:
#   strip   - drop the first dimension of the value (`squeeze(val, 1)`)
#   rebatch - convert the parameter via `paramvalue` rather than using
#             `param.x` directly
type AlterParam
  param::Flux.Param
  strip::Bool
  rebatch::Bool
end
# Resolve an `AlterParam` to an array: when `rebatch` is set the wrapped
# parameter goes through `paramvalue`, otherwise its `x` field is used as-is;
# when `strip` is set the first dimension of the result is squeezed away.
function paramvalue(p::AlterParam)
  raw = if p.rebatch
    paramvalue(p.param)
  else
    p.param.x
  end
  return p.strip ? squeeze(raw, 1) : raw
end
2017-02-23 17:32:06 +00:00
# Result of converting a Flux model to an MXNet symbolic graph.
#   node   - the MXNet symbolic node that gets bound to an executor
#   params - parameters by name; each entry is read through `paramvalue`
#   stacks - handed to `@mxerr` when binding
#            (presumably stack traces for error reporting; TODO confirm)
type Graph
  node::mx.SymbolicNode
  params::Dict{Symbol,Any}
  stacks::Dict{Any,Any}
end
2017-02-23 18:48:46 +00:00
# Allocate a zero-filled `mx.NDArray` for every parameter in `g`, shaped like
# the parameter's `paramvalue`. Only the shapes are taken from the
# parameters; their values are not copied here.
function mxparams(g::Graph)
  arrays = Dict{Symbol,mx.NDArray}()
  for key in keys(g.params)
    shape = size(paramvalue(g.params[key]))
    arrays[key] = mx.zeros(shape)
  end
  return arrays
end
2017-01-30 18:05:05 +00:00
# A Flux model compiled to an MXNet executor.
#   model - the original model that was passed to `mxnet`
#   graph - the `Graph` it was converted to
#   grads - gradient arrays by argument name (passed to `mx.bind` as `args_grad`)
#   exec  - the bound MXNet executor
type Model <: Flux.Model
  model::Any
  graph::Graph
  grads::Dict{Symbol,Any}
  exec::mx.Executor
end
2017-01-30 18:05:05 +00:00
# Copy the current parameter values into the executor's argument arrays.
# Executor arguments without a matching entry in `model.graph.params` are
# left untouched. Returns `model`.
function loadparams!(model::Model)
  known = model.graph.params
  for (name, arr) in model.exec.arg_dict
    if haskey(known, name)
      copy!(arr, paramvalue(known[name]))
    end
  end
  return model
end
2017-01-30 18:05:05 +00:00
# Compile a Flux model into an MXNet-backed `Model`.
#
# `input` is forwarded to `mx.zeros` to allocate the `:input` argument and its
# gradient (presumably the input shape; TODO confirm against callers).
# Gradients are bound with `mx.GRAD_ADD`, so they accumulate until reset.
function mxnet(model::Flux.Model, input)
  graph = tograph(model, mx.Variable(:input))
  # Arguments and gradients each get their own freshly allocated arrays.
  args = mxparams(graph)
  args[:input] = mx.zeros(input)
  grads = mxparams(graph)
  grads[:input] = mx.zeros(input)
  model = @mxerr graph.stacks Model(model, graph, grads,
                                    mx.bind(graph.node, args = args,
                                            args_grad = grads,
                                            grad_req = mx.GRAD_ADD))
  # Overwrite the zero-initialised arguments with the real parameter values.
  loadparams!(model)
  return model
end
2017-01-30 18:05:05 +00:00
# Run a forward pass: load `input` into the executor's `:input` argument,
# execute with `is_train = true`, and return a copy of the first output.
function runmodel(model::Model, input)
  exec = model.exec
  copy!(exec.arg_dict[:input], input)
  mx.forward(exec, is_train = true)
  return copy(exec.outputs[1])
end
2017-02-21 12:58:31 +00:00
# Apply the model to a `Batch`: the raw data has its first dimension moved
# last before the forward pass, and the output has its last dimension moved
# back to the front before being re-wrapped with `rebatch`.
function (m::Model)(x::Batch)
  nd = rebatch_last(rawbatch(x))
  out = rebatch_first(runmodel(m, nd))
  return rebatch(out)
end
2017-01-30 18:05:15 +00:00
# Apply the model to a single input: wrap it with `batchone`, run the batched
# method, and return the first element of the result.
function (m::Model)(x)
  return first(m(batchone(x)))
end
2017-02-23 21:06:46 +00:00
# Copy a Julia array into a newly allocated MXNet array of the same size.
function tond(xs::AArray)
  nd = mx.zeros(size(xs))
  return copy!(nd, xs)
end
# Run a backward pass with output gradient `Δ` and return a copy of the
# gradient with respect to the input.
function runback!(model::Model, Δ)
  dinput = model.grads[:input]
  # Reset the input gradient before the backward pass writes into it.
  dinput[:] = 0
  mx.backward(model.exec, tond(Δ))
  return copy(dinput)
end
2017-02-23 21:06:46 +00:00
# Backpropagate a `Batch` of output gradients. The raw gradient has its first
# dimension moved last for the backward pass; the resulting input gradient is
# converted back and re-wrapped with `rebatch`. `x` is not used here.
function Flux.back!(m::Model, Δ::Batch, x)
  nd = rebatch_last(rawbatch(Δ))
  dinput = rebatch_first(runback!(m, nd))
  return rebatch(dinput)
end
# Backpropagate a single output gradient: wrap it with `batchone`, call the
# batched method, and return the first element of the result.
function Flux.back!(m::Model, Δ, x)
  return first(Flux.back!(m, batchone(Δ), x))
end
2017-01-30 18:05:05 +00:00
# Apply one gradient-descent step with learning rate `η` to every executor
# argument, then zero the corresponding gradient array. Returns `model`.
function Flux.update!(model::Model, η)
  exec = model.exec
  for (arg, grad) in zip(exec.arg_arrays, exec.grad_arrays)
    # Both arrays are exposed read-write as Julia arrays inside the block.
    mx.@nd_as_jl rw = (arg, grad) begin
      arg .-= grad .* η
      grad[:] = 0
    end
  end
  return model
end
# MX FeedForward interface
# Stand-in for a trailing `softmax` that is emitted as `mx.SoftmaxOutput`.
#   name - forwarded as the `name` keyword of `mx.SoftmaxOutput`
type SoftmaxOutput
  name::Symbol
end
2017-02-23 16:58:10 +00:00
# Build the MXNet node for a `SoftmaxOutput` applied to `xs`.
function graph(s::SoftmaxOutput, xs)
  return mx.SoftmaxOutput(xs, name = s.name)
end
2017-01-28 17:02:49 +00:00
# Rewrite `model` so that its trailing `softmax` becomes a `SoftmaxOutput`
# named `name`, as required by the `mx.FeedForward` constructor in this file.
#
# Returns `SoftmaxOutput(name)` when `model` is `softmax` itself; otherwise
# returns a `Flux.Capacitor` whose final `softmax` vertex is replaced by a
# `SoftmaxOutput` vertex applied to the same input.
#
# Raises an error when the model has no graph, or when its graph does not end
# in a `softmax` with exactly one input.
function rewrite_softmax(model, name)
  model == softmax && return SoftmaxOutput(name)
  g = Flux.graph(model)
  # `===` is an identity check, so it cannot be affected by any `==` method
  # defined for graph vertices.
  if g === nothing || g.value != softmax || DataFlow.nin(g) != 1
    error("mx.FeedForward models must end with `softmax`")
  end
  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
2017-01-30 18:05:05 +00:00
# Build an `mx.FeedForward` from a Flux model.
#
# Keywords:
#   input   - name of the MXNet input variable
#   label   - name given to the `SoftmaxOutput` that replaces the final `softmax`
#   context - MXNet context forwarded to `mx.FeedForward`
#
# The model must end with `softmax` (enforced by `rewrite_softmax`).
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
  rewritten = rewrite_softmax(model, label)
  graph = tograph(rewritten, mx.Variable(input), feedforward=true)
  ff = mx.FeedForward(graph.node, context = context)
  # NOTE(review): `mxparams` allocates zero-filled arrays of the right shapes;
  # confirm that the actual parameter values are loaded elsewhere.
  if !isempty(graph.params)
    ff.arg_params = mxparams(graph)
  end
  return ff
end