remove mxnet for now
commit ee0c5ae14e
parent 18502158f0
@@ -1,9 +1,5 @@
# TODO: load backends lazily

# include("mxnet/mxnet.jl")
# using .MX
# export mxnet

include("tensorflow/tensorflow.jl")
using .TF
export tf
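Note: the `# TODO: load backends lazily` comment kept above is the context for dropping the MXNet backend rather than leaving it commented out. A minimal, hypothetical sketch of what lazy backend loading could look like; the `loadbackend` helper and the flag set are illustrative names, not part of this commit or of Flux:

# Hypothetical sketch (not part of this commit): include a backend only when
# it is first requested, instead of unconditionally at load time.
const backends_loaded = Set{Symbol}()

function loadbackend(name::Symbol)
  name in backends_loaded && return
  include(joinpath(dirname(@__FILE__), string(name), string(name, ".jl")))
  push!(backends_loaded, name)
end

# usage: loadbackend(:mxnet); loadbackend(:tensorflow)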
@@ -1,75 +0,0 @@
cvalue(x) = x
cvalue(c::Constant) = c.value
cvalue(v::Vertex) = cvalue(value(v))

graph(vars, model, args...) = node(model, args...)

graph(vars, x::mx.SymbolicNode) = x

# TODO: detect parameters used more than once
function graph{T<:AArray}(vars, p::Flux.Param{T})
  value = p.x
  id = gensym()
  vars[id] = value
  return mx.Variable(id)
end

function graph(vars, model::Model, args...)
  g = Flux.graph(model)
  g = Flow.mapconst(g) do x
    !isa(x, Flux.ModelInput) ? x :
    isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
  end
  postwalk(g) do v
    vertex(graph(vars, cvalue(v), cvalue.(inputs(v))...))
  end |> value
end

type SoftmaxOutput
  name::Symbol
end

function rewrite_softmax(model, name)
  model == softmax && return SoftmaxOutput(name)
  g = Flux.graph(model)
  (g == nothing || value(g) ≠ softmax || Flow.nin(g) ≠ 1) && error("mx.FeedForward models must end with `softmax`")
  return Flux.Capacitor(vertex(SoftmaxOutput(name), g[1]))
end

# Built-in implementations

node(::typeof(*), args...) = mx.dot(reverse(args)...)
node(::typeof(+), args...) = mx.broadcast_plus(args...)
node(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
node(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
node(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
node(::typeof(flatten), x) = mx.Flatten(data = x)

node(::typeof(softmax), xs) =
  mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))

node(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)

node(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
node(::typeof(vcat), a...) = node(cat, 1, a...)

graph(vars, ::Input, x) = x

graph(vars, c::Conv, x) =
  mx.Convolution(data = x,
                 kernel = c.size,
                 num_filter = c.features,
                 stride = c.stride)

graph(vars, p::MaxPool, x) =
  mx.Pooling(data = x,
             pool_type = :max,
             kernel = p.size,
             stride = p.stride)

# TODO: fix the initialisation issue
graph(vars, d::Dense, x) =
  mx.FullyConnected(data = x,
                    num_hidden = size(d.W.x, 1),
                    weight = graph(vars, d.W),
                    bias = graph(vars, d.b))
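For reference, the removed file mapped Flux primitives to MXNet symbolic operators purely through the `node` dispatch above, so supporting a further primitive was a one-line method in the same style. An illustrative sketch only, assuming MXNet.jl exposes an `mx.exp` symbolic operator (not verified here):

# Illustrative extra `node` method, in the style of the removed file;
# `mx.exp` is an assumed symbolic operator, not part of the original code.
node(::typeof(exp), x) = mx.exp(x)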
@@ -1,97 +0,0 @@
using MacroTools

type MXModel <: Model
  model::Any
  params::Dict{Symbol,Any}
  grads::Dict{Symbol,Any}
  exec::mx.Executor
end

Base.show(io::IO, m::MXModel) =
  print(io, "MXModel($(m.model))")

mxdims(dims::NTuple) = reverse(dims)

mxdims(n::Integer) = mxdims((n,))

function tond!(nd::mx.NDArray, xs::AArray)
  mx.copy_ignore_shape!(nd, xs')
  nd
end

tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs)

fromnd(xs::mx.NDArray) = copy(xs)'

ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))

function mxargs(args)
  map(args) do kv
    arg, value = kv
    arg => tond(value)
  end
end

function mxgrads(mxargs)
  map(mxargs) do kv
    arg, value = kv
    arg => mx.zeros(size(value))
  end
end

function load!(model::MXModel)
  for (name, arr) in model.exec.arg_dict
    haskey(model.params, name) && tond!(arr, model.params[name])
  end
  return model
end

function mxgraph(model, input; vars = true)
  vars = vars ? Dict{Symbol,Any}() : nothing
  node = graph(vars, model, mx.Variable(input))
  return node, vars
end

function mxnet(model::Model, input)
  node, vars = mxgraph(model, :input)
  args = merge(mxargs(vars), Dict(:input => mx.zeros(mxdims(input))))
  grads = mxgrads(args)
  model = MXModel(model, vars, grads,
                  mx.bind(node, args = args,
                          args_grad = grads,
                          grad_req = mx.GRAD_ADD))
  load!(model)
  return model
end

function (model::MXModel)(input)
  tond!(model.exec.arg_dict[:input], input)
  mx.forward(model.exec, is_train = true)
  fromnd(model.exec.outputs[1])
end

function Flux.back!(model::MXModel, Δ, x)
  ndzero!(model.grads[:input])
  mx.backward(model.exec, tond(Δ))
  fromnd(model.grads[:input])
end

function Flux.update!(model::MXModel, η)
  for (arg, grad) in zip(model.exec.arg_arrays, model.exec.grad_arrays)
    mx.@nd_as_jl rw = (arg, grad) begin
      arg .-= grad .* η
      grad[:] = 0
    end
  end
  return model
end

# MX FeedForward interface

function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
  model = rewrite_softmax(model, label)
  node, vars = mxgraph(model, input)
  ff = mx.FeedForward(node, context = context)
  ff.arg_params = mxargs(vars)
  return ff
end
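Taken together, the removed model.jl gave MXNet-backed models the same call / `back!` / `update!` interface as native Flux models. A rough usage sketch of the API being deleted, mirroring the functions defined above; `m`, `x`, and `Δ` are hypothetical (a Flux model supported by the graph translation, its input, and an output gradient of matching size):

# Illustrative only; not part of the original source.
mxm = mxnet(m, (1, 10))          # bind an MXNet executor for input shape (1, 10)
y   = mxm(x)                     # forward pass via mx.forward
∇x  = Flux.back!(mxm, Δ, x)      # backward pass, returns the input gradient
Flux.update!(mxm, 0.01)          # arg .-= 0.01 .* grad, then zero the gradients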
@@ -1,10 +0,0 @@
module MX

using MXNet, Flow, ..Flux

export mxnet

include("graph.jl")
include("model.jl")

end