Flux.jl/src/backend/mxnet/model.jl

151 lines
3.7 KiB
Julia
Raw Normal View History

2017-03-09 00:13:26 +00:00
using Flux: runrawbatched
2017-01-28 17:02:49 +00:00
2017-03-14 17:56:03 +00:00
# Wraps a parameter together with a pair of conversion functions.
# `load` is applied to `param.x` whenever the value is read (see the `size` and
# `copy!` methods for this type). `store` is kept alongside it but is not used
# by the code visible in this file — presumably the inverse of `load`; confirm
# against the code that constructs `AlterParam`.
struct AlterParam
    param  # object exposing the raw value as `param.x`
    load   # function applied to `param.x` when reading
    store  # counterpart of `load` for writing back (usage not shown here)
end
2017-03-08 17:35:15 +00:00
# Size of the wrapped parameter *after* it has been passed through `p.load`.
function Base.size(p::AlterParam)
    loaded = p.load(p.param.x)
    return size(loaded)
end
2017-03-12 14:51:55 +00:00
# Copy the loaded form of the wrapped parameter (`p.load(p.param.x)`) into `xs`.
function Base.copy!(xs, p::AlterParam)
    loaded = p.load(p.param.x)
    return copy!(xs, loaded)
end
2017-02-23 18:48:46 +00:00
2017-03-08 21:41:13 +00:00
# For every key present in both `as` and `bs`, copy the contents of `bs[key]`
# into `as[key]` in place. Keys found in only one collection are ignored.
# Returns `nothing`.
function copyargs!(as, bs)
    shared = intersect(keys(as), keys(bs))
    foreach(shared) do id
        copy!(as[id], bs[id])
    end
    return nothing
end
2017-03-14 17:56:03 +00:00
# Description of a converted model graph, as consumed by the executor code in
# this file.
struct Graph
    input                 # input name(s); used as the key structure for `dictt`
    output                # output node(s); passed through `mxgroup` before binding
    params::Dict{Symbol,Any}  # named parameters; only `size`/`copy!` are used on them here
    stacks::Dict{Any,Any}     # handed to `@mxerr` when running the model
end
2017-02-23 18:48:46 +00:00
# Allocate one fresh `MXArray` per graph parameter, sized like that parameter.
# Only the sizes are taken from `g.params`; no values are copied here.
function mxparams(g::Graph)
    return Dict{Symbol,MXArray}(name => MXArray(size(param)) for (name, param) in g.params)
end
2017-03-30 18:16:24 +00:00
# Build a `Symbol => mx.NDArray` dictionary from `d` by taking the `data`
# field of every value.
function ndparams(d)
    nds = Dict{Symbol,mx.NDArray}()
    for (name, arr) in d
        nds[name] = arr.data
    end
    return nds
end
2017-02-23 21:42:34 +00:00
2017-03-14 17:56:03 +00:00
# A bound MXNet executor together with the Julia-side arrays it was bound to.
struct Exec
    graph::Graph                 # graph this executor was built from
    exec::mx.Executor            # underlying MXNet executor returned by `mx.bind`
    args::Dict{Symbol,MXArray}   # argument arrays (parameters and inputs) by name
    grads::Dict{Symbol,MXArray}  # gradient arrays, keyed like `args`
    outs::Vector{MXArray}        # wrapped output arrays of the executor
end
2017-03-08 21:41:13 +00:00
# Copy the graph's parameter values into the executor's argument arrays.
function loadparams!(exec::Exec)
    return copyargs!(exec.args, exec.graph.params)
end
# Copy the executor's argument arrays back into the graph's parameters.
function storeparams!(exec::Exec)
    return copyargs!(exec.graph.params, exec.args)
end
2017-01-28 17:02:49 +00:00
2017-03-06 17:20:15 +00:00
# Turn a (possibly nested) tuple of outputs into an `mx.Group`; anything that
# is not a tuple is returned unchanged.
function mxgroup(x)
    return x
end

function mxgroup(x::Tuple)
    members = map(mxgroup, x)
    return mx.Group(members...)
end
# Rebuild the (possibly nested) tuple structure of `x` from the flat queue
# `outs`. Each leaf removes the first element of `outs` and returns a copy of
# it, so `outs` is consumed from the front as the structure is walked.
function mxungroup(x, outs)
    head = shift!(outs)
    return copy(head)
end

function mxungroup(x::Tuple, outs)
    return map(x) do el
        mxungroup(el, outs)
    end
end
2017-03-30 18:16:24 +00:00
# Flatten a (possibly nested) tuple of keys `ks` and a matching structure of
# values `vs` into the dictionary `d`, pairing them position by position.
# `d` is mutated and returned.
function dictt(ks::Tuple, vs, d = Dict())
    foreach(1:length(ks)) do i
        dictt(ks[i], vs[i], d)
    end
    return d
end

# Leaf case: a single key/value pair is stored directly in `d`.
function dictt(k, v, d = Dict())
    d[k] = v
    return d
end
# Bind `graph` to an MXNet executor for inputs shaped like `input...`.
#
# Two independent sets of arrays are allocated — one for arguments, one for
# gradients — each holding an array per graph parameter plus an array per
# input (sized like the corresponding input). The executor is bound with
# `grad_req = mx.GRAD_ADD`, and the graph's parameter values are loaded into
# the argument arrays before the `Exec` is returned.
function executor(graph::Graph, input...)
    # Fresh arrays on every call: `args` and `grads` must not share storage.
    function allocate()
        params = mxparams(graph)
        inputs = dictt(graph.input, mapt(d -> MXArray(size(d)), input))
        return merge(params, inputs)
    end
    args = allocate()
    grads = allocate()
    mxexec = mx.bind(mxgroup(graph.output),
                     args = ndparams(args),
                     args_grad = ndparams(grads),
                     grad_req = mx.GRAD_ADD)
    exec = Exec(graph, mxexec, args, grads, MXArray.(mxexec.outputs))
    loadparams!(exec)
    return exec
end
2017-03-30 18:25:54 +00:00
# Run a forward pass: copy `input...` into the executor's input arrays, call
# `mx.forward` with `is_train = true`, and return the outputs arranged in the
# same structure as `exec.graph.output`.
function (exec::Exec)(input...)
    for (name, value) in dictt(exec.graph.input, input)
        copy!(exec.args[name], value)
    end
    mx.forward(exec.exec, is_train = true)
    # `mxungroup` consumes its queue, so hand it a copy of the output vector.
    return mxungroup(exec.graph.output, copy(exec.outs))
end
2017-03-08 21:41:13 +00:00
# Run a backward pass with output gradient `Δ` and return a copy of the
# gradient with respect to the first input.
function Flux.back!(exec::Exec, Δ)
    input_grad = exec.grads[exec.graph.input[1]]
    # The executor is bound with `mx.GRAD_ADD`, so clear the input gradient
    # before running the backward pass.
    input_grad[:] = 0
    mx.backward(exec.exec, MXArray(Δ).data)
    return copy(input_grad)
end
2017-03-08 21:41:13 +00:00
# Apply one gradient step with learning rate `η` to every argument array of
# the executor (`arg .-= grad .* η`), zero each gradient afterwards, then write
# the updated values back to the graph's parameters. Returns `exec`.
function Flux.update!(exec::Exec, η)
    for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
        mx.@nd_as_jl rw = (arg, grad) begin
            arg .-= grad .* η
            grad[:] = 0
        end
    end
    storeparams!(exec)
    return exec
end
# TODO: if `last` changes, update params appropriately
2017-03-14 17:56:03 +00:00
# Wrapper that runs a Flux model through MXNet.
mutable struct Model <: Flux.Model
    model::Any               # the wrapped model
    execs::Dict{Tuple,Exec}  # executors cached by input size
    graph::Graph             # converted graph; left undefined until first call
    last::Exec               # most recently used executor; initially undefined
    # `graph` and `last` are deliberately left uninitialised: the call method
    # checks `isdefined(m, :graph)` and builds the graph lazily.
    Model(model) = new(model, Dict())
end
2017-03-30 17:14:08 +00:00
# Wrap `model` in an MXNet-backed `Model`.
function mxnet(model)
    return Model(model)
end
2017-03-08 21:41:13 +00:00
import Base: @get!
2017-03-30 18:16:24 +00:00
# Fetch the executor cached for inputs of this size, building and caching a
# new one from `m.graph` on first use.
function executor(m::Model, input)
    return get!(m.execs, size(input)) do
        executor(m.graph, input)
    end
end
2017-03-08 21:41:13 +00:00
2017-03-09 00:13:26 +00:00
# Run the model on `x`. The graph is built lazily on the first call; each
# batch handed over by `runrawbatched` is then run through the executor
# matching its size, which is also recorded as `m.last`.
function (m::Model)(x)
    if !isdefined(m, :graph)
        # The original referenced an undefined name `input` here; the only
        # value in scope is the argument `x`.
        m.graph = tograph(m.model, mapt(_ -> gensym("input"), x))
    end
    @mxerr m.graph.stacks runrawbatched(x) do batch
        m.last = exec = executor(m, batch)
        exec(batch)
    end
end
2017-03-09 00:13:26 +00:00
# Backpropagate `Δ` through the executor that was cached for inputs of the
# same size as `x`, recording it as `m.last`. Indexing `m.execs` throws a
# `KeyError` if no executor exists for that size (i.e. no forward pass has
# been run with such an input).
function Flux.back!(m::Model, Δ, x)
    runrawbatched(Δ, x) do Δ, x
        m.last = exec = m.execs[size(x)]
        back!(exec, Δ)
    end
end
# Update the most recently used executor with learning rate `η`; returns `m`.
function Flux.update!(m::Model, η)
    update!(m.last, η)
    return m
end
2017-01-28 17:02:49 +00:00
# MX FeedForward interface
2017-03-14 17:56:03 +00:00
# Marker node standing in for a trailing `softmax`; it is converted to
# `mx.SoftmaxOutput` with the stored name when the graph is built.
struct SoftmaxOutput
    name::Symbol  # passed as the `name` keyword of `mx.SoftmaxOutput`
end
2017-02-23 16:58:10 +00:00
# Graph conversion for `SoftmaxOutput`: emit an `mx.SoftmaxOutput` over `xs`
# carrying the node's name.
function graph(s::SoftmaxOutput, xs)
    label = s.name
    return mx.SoftmaxOutput(xs, name = label)
end
2017-01-28 17:02:49 +00:00
# Replace the final `softmax` of `model` with a `SoftmaxOutput(name)` node.
#
# If `model` is `softmax` itself, the replacement node is returned directly.
# Otherwise the model's graph must exist, have `softmax` as its top-level
# value, and take exactly one input; anything else raises an error.
function rewrite_softmax(model, name)
    model == softmax && return SoftmaxOutput(name)
    g = Flux.graph(model)
    # The two `≠` comparisons were lost in extraction and are restored here.
    if g === nothing || g.value ≠ softmax || DataFlow.nin(g) ≠ 1
        error("mx.FeedForward models must end with `softmax`")
    end
    return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
2017-01-30 18:05:05 +00:00
# Build an `mx.FeedForward` from a Flux model.
#
# Keywords:
#   input   — passed to `tograph` as the graph input (default `:data`)
#   label   — name given to the `SoftmaxOutput` node (default `:softmax`)
#   context — MXNet context for the `FeedForward` (default `mx.cpu()`)
#
# The model must end with `softmax` (enforced by `rewrite_softmax`).
function mx.FeedForward(model::Flux.Model;
                        input = :data, label = :softmax, context = mx.cpu())
    rewritten = rewrite_softmax(model, label)
    g = tograph(rewritten, input, feedforward=true)
    ff = mx.FeedForward(g.output, context = context)
    # NOTE(review): `mxparams` only allocates arrays of the right size; no
    # parameter values are copied into them here — confirm this is intended.
    isempty(g.params) || (ff.arg_params = ndparams(mxparams(g)))
    return ff
end