custom mxnet context
This commit is contained in:
parent
fe0bddd98d
commit
020ae616cc
@ -13,10 +13,10 @@ struct Graph
|
|||||||
stacks::Dict{Any,Any}
|
stacks::Dict{Any,Any}
|
||||||
end
|
end
|
||||||
|
|
||||||
function mxparams(ps)
|
function mxparams(ps, ctx)
|
||||||
params = Dict{Symbol,MXArray}()
|
params = Dict{Symbol,MXArray}()
|
||||||
for (name, param) in ps
|
for (name, param) in ps
|
||||||
params[name] = MXArray(size(param))
|
params[name] = MXArray(size(param), ctx)
|
||||||
end
|
end
|
||||||
return params
|
return params
|
||||||
end
|
end
|
||||||
@ -25,6 +25,7 @@ ndparams(d) = Dict{Symbol,mx.NDArray}(k => v.data for (k, v) in d)
|
|||||||
|
|
||||||
struct Exec
|
struct Exec
|
||||||
graph::Graph
|
graph::Graph
|
||||||
|
ctx::mx.Context
|
||||||
exec::mx.Executor
|
exec::mx.Executor
|
||||||
args::Dict{Symbol,MXArray}
|
args::Dict{Symbol,MXArray}
|
||||||
grads::Dict{Symbol,MXArray}
|
grads::Dict{Symbol,MXArray}
|
||||||
@ -41,16 +42,17 @@ mxungroup(x::Tuple, outs) = map(x -> mxungroup(x, outs), x)
|
|||||||
|
|
||||||
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
|
dictt(xs, ys) = Dict(zip(collectt(xs), collectt(ys)))
|
||||||
|
|
||||||
function executor(graph::Graph, input...)
|
function executor(graph::Graph, input...; ctx = mx.cpu())
|
||||||
shapecheckt(graph.input, input)
|
shapecheckt(graph.input, input)
|
||||||
args = merge(mxparams(graph.params), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
args = merge(mxparams(graph.params, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
|
||||||
grads = filter((a, b) -> b isa Flux.Param, graph.params)
|
grads = filter((a, b) -> b isa Flux.Param, graph.params)
|
||||||
grads = merge(mxparams(grads), dictt(graph.input, mapt(d->MXArray(size(d)), input)))
|
grads = merge(mxparams(grads, ctx), dictt(graph.input, mapt(d->MXArray(size(d), ctx), input)))
|
||||||
exec = mx.bind(mxgroup(graph.output),
|
exec = mx.bind(mxgroup(graph.output),
|
||||||
|
context = ctx,
|
||||||
args = ndparams(args),
|
args = ndparams(args),
|
||||||
args_grad = ndparams(grads),
|
args_grad = ndparams(grads),
|
||||||
grad_req = mx.GRAD_ADD)
|
grad_req = mx.GRAD_ADD)
|
||||||
exec = Exec(graph, exec, args, grads, MXArray.(exec.outputs))
|
exec = Exec(graph, ctx, exec, args, grads, MXArray.(exec.outputs))
|
||||||
loadparams!(exec)
|
loadparams!(exec)
|
||||||
return exec
|
return exec
|
||||||
end
|
end
|
||||||
@ -63,7 +65,7 @@ end
|
|||||||
|
|
||||||
function Flux.back!(exec::Exec, Δ)
|
function Flux.back!(exec::Exec, Δ)
|
||||||
mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
|
mapt(k -> exec.grads[k][:] = 0, exec.graph.input)
|
||||||
mx.backward(exec.exec, map(x -> MXArray(x).data, collectt(Δ)))
|
mx.backward(exec.exec, map(x -> MXArray(x, exec.ctx).data, collectt(Δ)))
|
||||||
mapt(k -> copy(exec.grads[k]), exec.graph.input)
|
mapt(k -> copy(exec.grads[k]), exec.graph.input)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -79,22 +81,28 @@ function Flux.update!(exec::Exec, η)
|
|||||||
return exec
|
return exec
|
||||||
end
|
end
|
||||||
|
|
||||||
|
toctx(ctx::mx.Context) = ctx
|
||||||
|
toctx(c::Symbol) = c == :gpu ? mx.gpu() : mx.cpu()
|
||||||
|
|
||||||
# TODO: if `last` changes, update params appropriately
|
# TODO: if `last` changes, update params appropriately
|
||||||
|
|
||||||
mutable struct Model
|
mutable struct Model
|
||||||
model::Any
|
model::Any
|
||||||
|
ctx::mx.Context
|
||||||
execs::Dict{Tuple,Exec}
|
execs::Dict{Tuple,Exec}
|
||||||
graph::Graph
|
graph::Graph
|
||||||
last::Exec
|
last::Exec
|
||||||
Model(model) = new(model, Dict())
|
Model(model, ctx) = new(model, ctx, Dict())
|
||||||
end
|
end
|
||||||
|
|
||||||
mxnet(model) = Model(model)
|
mxnet(model, ctx = :cpu) = Model(model, toctx(ctx))
|
||||||
|
|
||||||
import Base: @get!
|
import Base: @get!
|
||||||
|
|
||||||
# TODO: dims having its own type would be useful
|
# TODO: dims having its own type would be useful
|
||||||
executor(m::Model, input...) = @get!(m.execs, mapt(size, input), executor(m.graph, input...))
|
executor(m::Model, input...) =
|
||||||
|
@get!(m.execs, mapt(size, input),
|
||||||
|
executor(m.graph, input...; ctx = m.ctx))
|
||||||
|
|
||||||
function (m::Model)(xs...)
|
function (m::Model)(xs...)
|
||||||
@mxerr m.graph.stacks begin
|
@mxerr m.graph.stacks begin
|
||||||
@ -134,10 +142,10 @@ function rewrite_softmax(model, name)
|
|||||||
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
|
return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
|
||||||
end
|
end
|
||||||
|
|
||||||
function FeedForward(model; input = :data, label = :softmax, context = mx.cpu())
|
function FeedForward(model; input = :data, label = :softmax, ctx = mx.cpu())
|
||||||
model = rewrite_softmax(model, label)
|
model = rewrite_softmax(model, label)
|
||||||
graph = tograph(model, input, feedforward=true)
|
graph = tograph(model, input, feedforward=true)
|
||||||
ff = mx.FeedForward(graph.output, context = context)
|
ff = mx.FeedForward(graph.output, context = context)
|
||||||
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params)))
|
isempty(graph.params) || (ff.arg_params = ndparams(mxparams(graph.params, ctx)))
|
||||||
return ff
|
return ff
|
||||||
end
|
end
|
||||||
|
@ -12,7 +12,8 @@ end
|
|||||||
|
|
||||||
MXArray(data::mx.NDArray) = MXArray(data, Array{Float32}(size(data)))
|
MXArray(data::mx.NDArray) = MXArray(data, Array{Float32}(size(data)))
|
||||||
|
|
||||||
MXArray(dims::Dims) = MXArray(mx.zeros(reverse(dims)))
|
# TODO: split cpu/gpu mxarrays
|
||||||
|
MXArray(dims::Dims, ctx = mx.cpu()) = MXArray(mx.zeros(reverse(dims), ctx))
|
||||||
|
|
||||||
Base.size(xs::MXArray) = reverse(size(xs.data))
|
Base.size(xs::MXArray) = reverse(size(xs.data))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user