2017-05-01 15:57:51 +00:00
|
|
|
using Flux: collectt, shapecheckt
|
|
|
|
|
2017-03-14 17:56:03 +00:00
|
|
|
"""
    AlterParam(param, load, store)

Pair a parameter with two callables. The `size` and `copy!` methods defined for
this type read `param.x` and pass it through `load` before use.
"""
struct AlterParam
    # Wrapped parameter object; its `x` field holds the value that gets loaded.
    param
    # Callable applied to `param.x` whenever the value is measured or copied out.
    load
    # NOTE(review): not called in this part of the file — presumably the
    # counterpart of `load` for writing values back; confirm against callers.
    store
end
|
|
|
|
|
2017-03-08 17:35:15 +00:00
|
|
|
# Size of the wrapped parameter's value after it has been passed through `load`.
function Base.size(p::AlterParam)
    loaded = p.load(p.param.x)
    return size(loaded)
end
|
2017-03-12 14:51:55 +00:00
|
|
|
# Copy the `load`-transformed value of the wrapped parameter into `xs`.
function Base.copy!(xs, p::AlterParam)
    loaded = p.load(p.param.x)
    return copy!(xs, loaded)
end
|
2017-02-23 18:48:46 +00:00
|
|
|
|
2017-03-08 21:41:13 +00:00
|
|
|
"""
    copyargs!(as, bs)

For every key present in both `as` and `bs`, copy `bs[key]` into `as[key]` in
place with `copy!`. Keys that appear in only one of the two collections are
left alone. Returns `nothing`.
"""
function copyargs!(as, bs)
    shared = intersect(keys(as), keys(bs))
    foreach(shared) do id
        copy!(as[id], bs[id])
    end
    return nothing
end
|
|
|
|
|
2017-03-14 17:56:03 +00:00
|
|
|
"""
    Graph(input, output, params, stacks)

Description of a model graph as consumed by `executor`.
"""
struct Graph
    # Input name(s); zipped against the actual inputs when an executor is built.
    input
    # Output node(s); grouped with `mxgroup` and handed to `mx.bind`.
    output
    # Named parameters; sized with `size` and synchronised with `copy!`.
    params::Dict{Symbol,Any}
    # Passed to `@mxerr` when the model is called.
    # NOTE(review): presumably stack traces used for error reporting — confirm.
    stacks::Dict{Any,Any}
end
|
|
|
|
|
2017-05-04 14:09:18 +00:00
|
|
|
"""
    mxparams(ps)

Build a `Dict{Symbol,MXArray}` with one entry per `(name, param)` pair in `ps`,
where each value is a new `MXArray` constructed from `size(param)`.
"""
function mxparams(ps)
    return Dict{Symbol,MXArray}(name => MXArray(size(param)) for (name, param) in ps)
end
|
|
|
|
|
2017-03-30 18:16:24 +00:00
|
|
|
"""
    ndparams(d)

Map every `(name, value)` pair of `d` to `name => value.data`, collected into a
`Dict{Symbol,mx.NDArray}`.
"""
function ndparams(d)
    nds = Dict{Symbol,mx.NDArray}()
    for (name, arr) in d
        nds[name] = arr.data
    end
    return nds
end
|
2017-02-23 21:42:34 +00:00
|
|
|
|
2017-03-14 17:56:03 +00:00
|
|
|
"""
    Exec(graph, exec, args, grads, outs)

A bound MXNet executor together with the buffers it was bound against.
"""
struct Exec
    # Graph this executor was built from.
    graph::Graph
    # Underlying executor returned by `mx.bind`.
    exec::mx.Executor
    # Argument buffers (parameters and inputs), keyed by name.
    args::Dict{Symbol,MXArray}
    # Gradient buffers, keyed by name.
    grads::Dict{Symbol,MXArray}
    # Output buffers wrapping the executor's `outputs`.
    outs::Vector{MXArray}
end
|
|
|
|
|
2017-03-08 21:41:13 +00:00
|
|
|
# Copy the graph's parameter values into the executor's argument buffers.
function loadparams!(exec::Exec)
    return copyargs!(exec.args, exec.graph.params)
end
|
|
|
|
# Copy the executor's argument buffers back into the graph's parameters.
function storeparams!(exec::Exec)
    return copyargs!(exec.graph.params, exec.args)
end
|
2017-01-28 17:02:49 +00:00
|
|
|
|
2017-03-06 17:20:15 +00:00
|
|
|
# Non-tuple values are returned unchanged.
function mxgroup(x)
    return x
end

# A tuple is converted element by element (recursively) and the results are
# splatted into `mx.Group`.
function mxgroup(x::Tuple)
    members = map(mxgroup, x)
    return mx.Group(members...)
end
|
|
|
|
# Leaf case: remove the first element of `outs` (mutating `outs`) and return a
# copy of it. `x` only selects the method.
function mxungroup(x, outs)
    head = shift!(outs)
    return copy(head)
end

# Tuple case: rebuild the shape of `x`, consuming entries of `outs` in order.
function mxungroup(x::Tuple, outs)
    return map(x) do item
        mxungroup(item, outs)
    end
end
|
|
|
|
|
2017-05-01 15:57:51 +00:00
|
|
|
# Build a `Dict` pairing the entries of `collectt(xs)` with those of
# `collectt(ys)`, position by position.
function dictt(xs, ys)
    ks = collectt(xs)
    vs = collectt(ys)
    return Dict(zip(ks, vs))
end
|
2017-03-30 18:16:24 +00:00
|
|
|
|
|
|
|
"""
    executor(graph::Graph, input...)

Bind `graph` into an `Exec` for inputs shaped like `input`.

The inputs are first validated with `shapecheckt`. Argument and gradient
buffers are allocated for the graph's parameters and for each input; only
parameters that are `Flux.Param`s receive a gradient buffer. The graph output
is bound with `mx.bind` using `mx.GRAD_ADD`, and the current parameter values
are loaded into the new executor before it is returned.
"""
function executor(graph::Graph, input...)
    shapecheckt(graph.input, input)

    # One new MXArray per input, sized like that input. Called twice so that the
    # argument and gradient dictionaries each get their own set of buffers.
    inputbuffers() = dictt(graph.input, mapt(d -> MXArray(size(d)), input))

    args = merge(mxparams(graph.params), inputbuffers())
    trainable = filter((a, b) -> b isa Flux.Param, graph.params)
    grads = merge(mxparams(trainable), inputbuffers())

    bound = mx.bind(mxgroup(graph.output),
                    args = ndparams(args),
                    args_grad = ndparams(grads),
                    grad_req = mx.GRAD_ADD)

    exec = Exec(graph, bound, args, grads, MXArray.(bound.outputs))
    loadparams!(exec)
    return exec
end
|
|
|
|
|
2017-03-30 18:25:54 +00:00
|
|
|
# Forward pass: copy each input into its argument buffer, run the executor with
# `is_train = true`, and return the outputs arranged like `graph.output`.
function (exec::Exec)(input...)
    for (name, value) in dictt(exec.graph.input, input)
        copy!(exec.args[name], value)
    end
    mx.forward(exec.exec, is_train = true)
    # `mxungroup` consumes the vector it is given, so pass it a copy.
    outs = copy(exec.outs)
    return mxungroup(exec.graph.output, outs)
end
|
|
|
|
|
2017-03-08 21:41:13 +00:00
|
|
|
# Backward pass for a bound executor: zero the input gradient buffers, run
# `mx.backward` with the head gradients `Δ`, and return copies of the input
# gradients arranged like `graph.input`.
function Flux.back!(exec::Exec, Δ)
    # The executor is bound with GRAD_ADD, so input gradients are cleared first.
    mapt(exec.graph.input) do name
        exec.grads[name][:] = 0
    end
    heads = map(δ -> MXArray(δ).data, collectt(Δ))
    mx.backward(exec.exec, heads)
    return mapt(name -> copy(exec.grads[name]), exec.graph.input)
end
|
|
|
|
|
2017-03-08 21:41:13 +00:00
|
|
|
# Gradient-descent step on a bound executor.
#
# For each (argument, gradient) pair of the underlying MXNet executor that has
# a gradient array, subtract `η` times the gradient from the argument and reset
# the gradient to zero. The updated arguments are then written back to the
# graph's parameters via `storeparams!`. Returns `exec`.
function Flux.update!(exec::Exec, η)
    for (arg, grad) in zip(exec.exec.arg_arrays, exec.exec.grad_arrays)
        # Arguments without a gradient array are skipped. Identity comparison is
        # used so the check cannot be affected by any `==` method on arrays.
        grad === nothing && continue
        mx.@nd_as_jl rw = (arg, grad) begin
            arg .-= grad .* η
            # Gradients accumulate (the executor is bound with GRAD_ADD), so
            # clear them after use.
            grad[:] = 0
        end
    end
    storeparams!(exec)
    return exec
end
|
|
|
|
|
|
|
|
# TODO: if `last` changes, update params appropriately
|
|
|
|
|
2017-03-14 17:56:03 +00:00
|
|
|
"""
    Model(model)

Wrapper around a Flux model that runs it through MXNet executors. The `graph`
and `last` fields are left undefined by the constructor and are assigned the
first time the model is called.
"""
mutable struct Model <: Flux.Model
    # The wrapped model.
    model::Any
    # Executors keyed by the sizes of the inputs they were bound for.
    execs::Dict{Tuple,Exec}
    # Graph built from `model`; assigned lazily on first call.
    graph::Graph
    # The executor used most recently.
    last::Exec

    function Model(model)
        return new(model, Dict{Tuple,Exec}())
    end
end
|
|
|
|
|
2017-03-30 17:14:08 +00:00
|
|
|
# Wrap an arbitrary model in an MXNet-backed `Model`.
function mxnet(model)
    return Model(model)
end
|
2017-03-08 21:41:13 +00:00
|
|
|
|
|
|
|
import Base: @get!
|
|
|
|
|
2017-03-30 18:50:03 +00:00
|
|
|
# TODO: dims having its own type would be useful
|
|
|
|
# Return the executor cached for inputs of these sizes, building and caching a
# new one from `m.graph` only when none exists yet.
function executor(m::Model, input...)
    return get!(m.execs, mapt(size, input)) do
        executor(m.graph, input...)
    end
end
|
2017-03-08 21:41:13 +00:00
|
|
|
|
2017-03-30 19:05:18 +00:00
|
|
|
# Forward pass: build the graph on first use (one generated input name per
# argument), fetch the executor for these input sizes, record it as `m.last`,
# and run it. The whole body runs under `@mxerr` with the graph's `stacks`.
function (m::Model)(xs...)
    @mxerr m.graph.stacks begin
        if !isdefined(m, :graph)
            m.graph = tograph(m.model, mapt(_ -> gensym("input"), xs)...)
        end
        exec = executor(m, xs...)
        m.last = exec
        exec(xs...)
    end
end
|
|
|
|
|
2017-03-30 18:50:03 +00:00
|
|
|
# Backward pass through the executor that was cached for inputs sized like
# `xs`. That executor becomes `m.last`. Indexing `m.execs` fails if no executor
# has been cached for these sizes.
function Flux.back!(m::Model, Δ, xs...)
    key = mapt(size, xs)
    exec = m.execs[key]
    m.last = exec
    return back!(exec, Δ)
end
|
|
|
|
|
|
|
|
# Apply an update with step size `η` to the most recently used executor and
# return the model itself.
function Flux.update!(m::Model, η)
    update!(m.last, η)
    return m
end
|
|
|
|
|
2017-03-30 19:05:18 +00:00
|
|
|
# Recurrent Models
|
|
|
|
|
|
|
|
using Flux: Stateful, SeqModel
|
|
|
|
|
2017-04-26 16:42:47 +00:00
|
|
|
# Convert the inner model of a `Stateful`, keeping its `istate` and `ostate`.
function mxnet(m::Stateful)
    inner = mxnet(m.model)
    return Stateful(inner, m.istate, m.ostate)
end
|
2017-03-30 19:05:18 +00:00
|
|
|
# Convert the inner model of a `SeqModel`, keeping its `steps`.
function mxnet(m::SeqModel)
    inner = mxnet(m.model)
    return SeqModel(inner, m.steps)
end
|
|
|
|
|
2017-01-28 17:02:49 +00:00
|
|
|
# MX FeedForward interface
|
|
|
|
|
2017-03-14 17:56:03 +00:00
|
|
|
"""
    SoftmaxOutput(name)

Placeholder node that is turned into `mx.SoftmaxOutput` with the given `name`
when the graph is built.
"""
struct SoftmaxOutput
    # Name forwarded to `mx.SoftmaxOutput`.
    name::Symbol
end
|
|
|
|
|
2017-02-23 16:58:10 +00:00
|
|
|
# Lower a `SoftmaxOutput` placeholder to `mx.SoftmaxOutput` applied to `xs`,
# carrying over the placeholder's name.
function graph(s::SoftmaxOutput, xs)
    return mx.SoftmaxOutput(xs, name = s.name)
end
|
2017-01-28 17:02:49 +00:00
|
|
|
|
|
|
|
"""
    rewrite_softmax(model, name)

Replace a trailing `softmax` in `model` with a `SoftmaxOutput(name)` node.

If `model` is `softmax` itself, the placeholder is returned directly.
Otherwise the model's graph must exist, have `softmax` as its final value and
take exactly one input; the result is a `Flux.Capacitor` whose final node is
the placeholder applied to that input.

Throws an error if the model does not end with `softmax`.
"""
function rewrite_softmax(model, name)
    model == softmax && return SoftmaxOutput(name)
    g = Flux.graph(model)
    # `=== nothing` tests for the sentinel by identity, so a custom `==` on
    # graph vertices cannot influence the check.
    if g === nothing || g.value ≠ softmax || DataFlow.nin(g) ≠ 1
        error("mx.FeedForward models must end with `softmax`")
    end
    return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end
|
|
|
|
|
2017-01-30 18:05:05 +00:00
|
|
|
# Build an `mx.FeedForward` from a Flux model.
#
# The model's trailing `softmax` is rewritten to a `SoftmaxOutput` named
# `label`, the result is converted to a graph with `input` as its input name,
# and the graph output is wrapped in an `mx.FeedForward` on `context`. When the
# graph has parameters, `arg_params` is set to arrays allocated by `mxparams`
# from their sizes.
function mx.FeedForward(model::Flux.Model; input = :data, label = :softmax, context = mx.cpu())
    rewritten = rewrite_softmax(model, label)
    g = tograph(rewritten, input, feedforward=true)
    ff = mx.FeedForward(g.output, context = context)
    if !isempty(g.params)
        ff.arg_params = ndparams(mxparams(g.params))
    end
    return ff
end
|