# Patch overview: adds an MXNet backend to Flux as the submodule `MX`
# (src/backend/mxnet/), plus the exports and package requirement it relies on.
# Lines starting with `#` that sit outside the hunks are reviewer annotations;
# they are not part of any hunk.

# REQUIRE: adds MXNet to the package requirements.
diff --git a/REQUIRE b/REQUIRE
index 70e314a2..b84fa07e 100644
--- a/REQUIRE
+++ b/REQUIRE
@@ -1 +1,2 @@
 julia 0.5-
+MXNet

# src/Flux.jl: includes backend/mxnet/mxnet.jl ahead of the closing `end # module`.
diff --git a/src/Flux.jl b/src/Flux.jl
index 39e653a5..f63f5344 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -21,4 +21,6 @@ include("layers/dense.jl")
 include("layers/sequence.jl")
 include("utils.jl")
 
+include("backend/mxnet/mxnet.jl")
+
 end # module

# src/backend/mxnet/graph.jl (new file): builds mx nodes from a model's graph.
#   cvalue(x)
#       Returns c.value for a Constant, recurses through value(v) for a Vertex,
#       and returns any other argument unchanged.
#   graph(vars, model, args...)
#       Generic fallback; forwards to node(model, args...).
#   graph(vars, x::mx.SymbolicNode)
#       Returns x as is.
#   graph(vars, value::AArray)
#       Stores the array in vars under a fresh gensym() key and returns
#       mx.Variable of that key.
#   graph(vars, model::Model, args...)
#       Takes Flux.graph(model) and replaces every Flux.Parameter constant:
#       args[x.name] when x.name is an Integer, getfield(model, x.name) otherwise;
#       non-Parameter constants are kept. It then postwalks the graph, calling
#       graph(vars, ...) on each vertex's cvalue and its inputs' cvalues, and
#       returns value(...) of the walked graph.
#   node(::typeof(*), ...) / node(::typeof(+), ...)
#       Map to mx.dot and mx.broadcast_plus respectively.
# NOTE(review): Constant, Vertex, value, inputs, vertex, postwalk and Model are
# not defined in this patch; presumably they come from Flow / Flux - confirm.
diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
new file mode 100644
index 00000000..b9b35f2f
--- /dev/null
+++ b/src/backend/mxnet/graph.jl
@@ -0,0 +1,30 @@
+cvalue(x) = x
+cvalue(c::Constant) = c.value
+cvalue(v::Vertex) = cvalue(value(v))
+
+graph(vars, model, args...) = node(model, args...)
+
+graph(vars, x::mx.SymbolicNode) = x
+
+# TODO: detect parameters used more than once
+function graph(vars, value::AArray)
+  id = gensym()
+  vars[id] = value
+  return mx.Variable(id)
+end
+
+function graph(vars, model::Model, args...)
+  g = Flux.graph(model)
+  g = Flow.mapconst(g) do x
+    !isa(x, Flux.Parameter) ? x :
+    isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
+  end
+  postwalk(g) do v
+    vertex(graph(vars, cvalue(v), cvalue.(inputs(v))...))
+  end |> value
+end
+
+# Built-in implementations
+
+node(::typeof(*), args...) = mx.dot(args...)
+node(::typeof(+), args...) = mx.broadcast_plus(args...)

# src/backend/mxnet/model.jl (new file): wraps a model together with an executor.
#   MXModel
#       Holds the source model, a Dict{Symbol,Any} of params and an mx.Executor.
#   mxdims(dims::NTuple)
#       Returns (1, dims...) for a 1-tuple and reverse(dims) otherwise.
#   mxargs(args)
#       Maps each key => value pair to key => mx.zeros(mxdims(size(value))).
#       Only the size of value is used; its contents are not copied here.
#   mxnet(model::Model, input)
#       Seeds vars with :input => mx.zeros(mxdims(input)), builds the node with
#       graph(vars, model, mx.Variable(:input)), and returns
#       MXModel(model, vars, mx.bind(node, args = mxargs(vars), grad_req = mx.GRAD_NOP)).
# NOTE(review): vars[:input] is created with mxdims(input) and is then passed
# through mxargs, which applies mxdims to its size a second time - confirm this
# is intended.
# NOTE(review): `input` looks like a dims tuple (cf. the commented-out example
# mxnet(d, (1,20)) in mxnet.jl) - confirm.
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
new file mode 100644
index 00000000..eea80ad6
--- /dev/null
+++ b/src/backend/mxnet/model.jl
@@ -0,0 +1,21 @@
+type MXModel
+  model::Any
+  params::Dict{Symbol,Any}
+  exec::mx.Executor
+end
+
+mxdims(dims::NTuple) =
+  length(dims) == 1 ? (1, dims...) : reverse(dims)
+
+function mxargs(args)
+  map(args) do kv
+    arg, value = kv
+    arg => mx.zeros(mxdims(size(value)))
+  end
+end
+
+function mxnet(model::Model, input)
+  vars = Dict{Symbol,Any}(:input => mx.zeros(mxdims(input)))
+  node = graph(vars, model, mx.Variable(:input))
+  MXModel(model, vars, mx.bind(node, args = mxargs(vars), grad_req = mx.GRAD_NOP))
+end

# src/backend/mxnet/mxnet.jl (new file): defines module MX, which uses MXNet,
# Flow and ..Flux and includes graph.jl and model.jl. The Dense / mxnet usage
# example at the bottom is commented out.
diff --git a/src/backend/mxnet/mxnet.jl b/src/backend/mxnet/mxnet.jl
new file mode 100644
index 00000000..14e23be7
--- /dev/null
+++ b/src/backend/mxnet/mxnet.jl
@@ -0,0 +1,12 @@
+module MX
+
+using MXNet, Flow, ..Flux
+
+include("graph.jl")
+include("model.jl")
+
+# d = Dense(20, 10)
+
+# model = mxnet(d, (1,20))
+
+end

# src/layers/dense.jl: exports Dense.
diff --git a/src/layers/dense.jl b/src/layers/dense.jl
index a36cf1eb..33b34b28 100644
--- a/src/layers/dense.jl
+++ b/src/layers/dense.jl
@@ -1,3 +1,5 @@
+export Dense
+
 @model type Dense
   W
   b

# src/utils.jl: exports AArray and defines it as a const alias of AbstractArray.
diff --git a/src/utils.jl b/src/utils.jl
index 7ce83918..e5b8b815 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,4 +1,6 @@
-export onehot, onecold
+export AArray, onehot, onecold
+
+const AArray = AbstractArray
 
 onehot(label, labels) = [i == label for i in labels]
 onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]