2017-03-06 17:20:15 +00:00
|
|
|
using Flux: batchone, unbatchone, rebatch
|
2017-01-28 17:02:49 +00:00
|
|
|
|
2017-02-23 18:48:46 +00:00
|
|
|
# Wraps a model parameter so MXNet sees a transformed view of it.
# `load`/`store` convert between the Flux-side value and the MXNet-side
# representation (see `Base.size(::AlterParam)` for how `load` is applied).
type AlterParam
  param   # the wrapped Flux parameter; has an `.x` field holding the raw value
  load    # function applied to the raw value when exposing it to MXNet
  store   # inverse transform for writing back — NOTE(review): not exercised in this chunk; confirm usage
end
|
|
|
|
|
2017-03-08 17:35:15 +00:00
|
|
|
# Size of the parameter as MXNet will see it: apply the `load` transform
# to the raw value and measure the result.
function Base.size(p::AlterParam)
  return size(p.load(p.param.x))
end
|
2017-02-23 18:48:46 +00:00
|
|
|
|
2017-02-23 17:32:06 +00:00
|
|
|
# Result of translating a Flux model into an MXNet symbolic graph.
type Graph
  output                    # mx symbol, or tuple of symbols for multi-output models
  params::Dict{Symbol,Any}  # parameters by name; values may be AlterParam wrappers
  stacks::Dict{Any,Any}     # node → Flux-side stack info, used by @mxerr for error reporting — TODO confirm
end
|
|
|
|
|
2017-02-23 18:48:46 +00:00
|
|
|
# Allocate a fresh MXArray buffer matching the shape of each parameter
# in the graph, keyed by parameter name.
function mxparams(g::Graph)
  buffers = Dict{Symbol,MXArray}()
  for (name, p) in g.params
    buffers[name] = MXArray(size(p))
  end
  return buffers
end
|
|
|
|
|
2017-03-08 17:35:15 +00:00
|
|
|
# Copy values from `bs` into `as`, in place, for every key the two
# dicts have in common; keys present in only one dict are untouched.
function copyargs!(as, bs)
  shared = intersect(keys(as), keys(bs))
  for k in shared
    copy!(as[k], bs[k])
  end
end
|
|
|
|
|
2017-03-08 17:35:15 +00:00
|
|
|
# Strip the MXArray wrappers, yielding a name → raw NDArray dict in the
# form the MXNet API expects.
function ndparams(d::Dict{Symbol,MXArray})
  return Dict(name => arr.data for (name, arr) in d)
end
|
2017-02-23 21:42:34 +00:00
|
|
|
|
2017-01-30 18:05:05 +00:00
|
|
|
# A Flux model compiled to an MXNet executor.
type Model <: Flux.Model
  model::Any                   # the original Flux model this wraps
  graph::Graph                 # its MXNet symbolic translation
  args::Dict{Symbol,MXArray}   # argument buffers: parameters plus :input
  grads::Dict{Symbol,MXArray}  # gradient buffers with the same keys as `args`
  outs::Vector{MXArray}        # output buffers mirroring exec.outputs
  exec::mx.Executor            # the bound MXNet executor
end
|
|
|
|
|
2017-03-08 17:35:15 +00:00
|
|
|
# Push Flux-side parameter values into the executor's argument buffers.
function loadparams!(model::Model)
  return copyargs!(model.args, model.graph.params)
end
|
|
|
|
# Write the executor's (possibly updated) argument values back to the
# Flux-side parameters.
function storeparams!(model::Model)
  return copyargs!(model.graph.params, model.args)
end
|
2017-01-28 17:02:49 +00:00
|
|
|
|
2017-03-06 17:20:15 +00:00
|
|
|
# Flatten a (possibly nested) tuple of symbols into one mx.Group so a
# multi-output model can be bound as a single executor; non-tuples pass
# through unchanged.
mxgroup(x) = x
function mxgroup(t::Tuple)
  return mx.Group(mxgroup.(t)...)
end
|
|
|
|
# Inverse of mxgroup: consume executor outputs from the front of `outs`
# (in order) and rebuild the tuple structure of `x`. A scalar takes one
# output; a tuple recurses element-wise.
mxungroup(x, outs) = copy(shift!(outs))
function mxungroup(t::Tuple, outs)
  return map(el -> mxungroup(el, outs), t)
end
|
|
|
|
|
2017-01-30 18:05:05 +00:00
|
|
|
# Compile a Flux model into an MXNet-backed `Model`, bound for inputs of
# the same shape as `input`.
function mxnet(model::Flux.Model, input)
  graph = tograph(model, mx.Variable(:input))
  # One buffer per parameter plus the input itself, for both values and grads.
  args = merge(mxparams(graph), Dict(:input => MXArray(input)))
  # NOTE(review): grads[:input] starts as a copy of the input's values, not
  # zeros; runback! zeroes it before each backward pass — confirm intended.
  grads = merge(mxparams(graph), Dict(:input => MXArray(input)))
  # Bind the symbolic graph. @mxerr maps MXNet errors back onto the Flux
  # stacks recorded in graph.stacks. GRAD_ADD accumulates into the grad
  # buffers, so they must be cleared manually (runback!, update!).
  exec = @mxerr graph.stacks mx.bind(mxgroup(graph.output),
                                     args = ndparams(args),
                                     args_grad = ndparams(grads),
                                     grad_req = mx.GRAD_ADD)
  model = Model(model, graph, args, grads, MXArray.(exec.outputs), exec)
  loadparams!(model)  # seed the executor with the model's current parameters
  return model
end
|
|
|
|
|
2017-01-30 18:05:05 +00:00
|
|
|
# Run a forward pass: load `input` into the executor, execute, and
# reassemble the outputs into the graph's (possibly tuple-shaped) output.
function runmodel(model::Model, input)
  copy!(model.args[:input], input)
  mx.forward(model.exec, is_train = true)  # is_train so a backward pass is possible
  # Copy outs first: mxungroup consumes the vector destructively via shift!.
  mxungroup(model.graph.output, copy(model.outs))
end
|
|
|
|
|
2017-03-08 17:35:15 +00:00
|
|
|
# Batched call: unwrap the batch, run the executor, re-wrap the result.
function (m::Model)(x::Batch)
  return rebatch(runmodel(m, rawbatch(x)))
end
|
2017-01-30 18:05:15 +00:00
|
|
|
|
2017-03-06 17:20:15 +00:00
|
|
|
# Unbatched call: wrap the sample as a batch of one, run, then unwrap.
function (m::Model)(x)
  return unbatchone(m(batchone(x)))
end
|
2017-01-30 18:05:15 +00:00
|
|
|
|
2017-02-23 21:06:46 +00:00
|
|
|
# Run a backward pass for output gradient Δ; returns the gradient with
# respect to the input.
function runback!(model::Model, Δ)
  # grad_req is GRAD_ADD, so the input-gradient accumulator must be
  # cleared before backward or gradients from previous calls would leak in.
  model.grads[:input][:] = 0
  mx.backward(model.exec, MXArray(Δ).data)
  copy(model.grads[:input])
end
|
|
|
|
|
2017-03-08 17:35:15 +00:00
|
|
|
# Batched backward pass. `x` is accepted for interface compatibility but
# the executor already holds the forward-pass input.
function Flux.back!(m::Model, Δ::Batch, x)
  return rebatch(runback!(m, rawbatch(Δ)))
end
|
2017-02-23 21:06:46 +00:00
|
|
|
|
|
|
|
# Unbatched backward pass: batch Δ up, run the batched method, take the
# single result back out.
function Flux.back!(m::Model, Δ, x)
  return first(Flux.back!(m, batchone(Δ), x))
end
|
|
|
|
|
2017-01-30 18:05:05 +00:00
|
|
|
# Vanilla SGD step: arg .-= η .* grad for every bound argument, then
# reset the gradient accumulators (required with grad_req = GRAD_ADD) and
# sync the updated values back to the Flux-side parameters.
function Flux.update!(model::Model, η)
  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
    # @nd_as_jl exposes both NDArrays as read-write Julia arrays in this block.
    mx.@nd_as_jl rw = (arg, grad) begin
      arg .-= grad .* η
      grad[:] = 0  # clear accumulator for the next backward pass
    end
  end
  storeparams!(model)
  return model
end
|
|
|
|
|
|
|
|
# MX FeedForward interface
|
|
|
|
|
|
|
|
# Placeholder layer standing in for `softmax` when building an
# mx.FeedForward graph; it is lowered to an mx.SoftmaxOutput node.
type SoftmaxOutput
  name::Symbol  # name of the output/label node, e.g. :softmax
end
|
|
|
|
|
2017-02-23 16:58:10 +00:00
|
|
|
# Lower a SoftmaxOutput placeholder to the corresponding MXNet node.
function graph(s::SoftmaxOutput, xs)
  return mx.SoftmaxOutput(xs, name = s.name)
end
|
2017-01-28 17:02:49 +00:00
|
|
|
|
|
|
|
# Replace a trailing `softmax` in `model` with a SoftmaxOutput(name)
# node, as the mx.FeedForward interface requires. Errors if the model
# does not end with a single-input softmax.
function rewrite_softmax(model, name)
  if model == softmax
    return SoftmaxOutput(name)
  end
  g = Flux.graph(model)
  if g == nothing || g.value ≠ softmax || DataFlow.nin(g) ≠ 1
    error("mx.FeedForward models must end with `softmax`")
  end
  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
|
|
|
|
|
2017-01-30 18:05:05 +00:00
|
|
|
# Build an mx.FeedForward network from a Flux model. The model's final
# softmax is rewritten to a SoftmaxOutput named `label`, and any known
# parameters are pre-seeded into the network.
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
  model = rewrite_softmax(model, label)
  net = tograph(model, mx.Variable(input), feedforward = true)
  ff = mx.FeedForward(net.output, context = context)
  if !isempty(net.params)
    ff.arg_params = ndparams(mxparams(net))
  end
  return ff
end
|