basic MXNet output

This commit is contained in:
Mike J Innes 2016-08-22 21:13:28 +01:00
parent 0496ce6bda
commit cab43611e3
7 changed files with 71 additions and 1 deletions

View File

@ -1 +1,2 @@
julia 0.5-
MXNet

View File

@ -21,4 +21,6 @@ include("layers/dense.jl")
include("layers/sequence.jl")
include("utils.jl")
include("backend/mxnet/mxnet.jl")
end # module

View File

@ -0,0 +1,30 @@
cvalue(x) = x
cvalue(c::Constant) = c.value
cvalue(v::Vertex) = cvalue(value(v))
graph(vars, model, args...) = node(model, args...)
graph(vars, x::mx.SymbolicNode) = x
# TODO: detect parameters used more than once
function graph(vars, value::AArray)
id = gensym()
vars[id] = value
return mx.Variable(id)
end
function graph(vars, model::Model, args...)
g = Flux.graph(model)
g = Flow.mapconst(g) do x
!isa(x, Flux.Parameter) ? x :
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
end
postwalk(g) do v
vertex(graph(vars, cvalue(v), cvalue.(inputs(v))...))
end |> value
end
# Built-in implemenations
node(::typeof(*), args...) = mx.dot(args...)
node(::typeof(+), args...) = mx.broadcast_plus(args...)

View File

@ -0,0 +1,21 @@
type MXModel
model::Any
params::Dict{Symbol,Any}
exec::mx.Executor
end
mxdims(dims::NTuple) =
length(dims) == 1 ? (1, dims...) : reverse(dims)
function mxargs(args)
map(args) do kv
arg, value = kv
arg => mx.zeros(mxdims(size(value)))
end
end
function mxnet(model::Model, input)
vars = Dict{Symbol,Any}(:input => mx.zeros(mxdims(input)))
node = graph(vars, model, mx.Variable(:input))
MXModel(model, vars, mx.bind(node, args = mxargs(vars), grad_req = mx.GRAD_NOP))
end

View File

@ -0,0 +1,12 @@
module MX
using MXNet, Flow, ..Flux
include("graph.jl")
include("model.jl")
# d = Dense(20, 10)
# model = mxnet(d, (1,20))
end

View File

@ -1,3 +1,5 @@
export Dense
@model type Dense
W
b

View File

@ -1,4 +1,6 @@
export onehot, onecold
export AArray, onehot, onecold
const AArray = AbstractArray
onehot(label, labels) = [i == label for i in labels]
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]