basic MXNet output
This commit is contained in:
parent
0496ce6bda
commit
cab43611e3
|
@ -21,4 +21,6 @@ include("layers/dense.jl")
|
|||
include("layers/sequence.jl")
|
||||
include("utils.jl")
|
||||
|
||||
include("backend/mxnet/mxnet.jl")
|
||||
|
||||
end # module
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
cvalue(x) = x
|
||||
cvalue(c::Constant) = c.value
|
||||
cvalue(v::Vertex) = cvalue(value(v))
|
||||
|
||||
graph(vars, model, args...) = node(model, args...)
|
||||
|
||||
graph(vars, x::mx.SymbolicNode) = x
|
||||
|
||||
# TODO: detect parameters used more than once
|
||||
function graph(vars, value::AArray)
|
||||
id = gensym()
|
||||
vars[id] = value
|
||||
return mx.Variable(id)
|
||||
end
|
||||
|
||||
function graph(vars, model::Model, args...)
|
||||
g = Flux.graph(model)
|
||||
g = Flow.mapconst(g) do x
|
||||
!isa(x, Flux.Parameter) ? x :
|
||||
isa(x.name, Integer) ? args[x.name] : getfield(model, x.name)
|
||||
end
|
||||
postwalk(g) do v
|
||||
vertex(graph(vars, cvalue(v), cvalue.(inputs(v))...))
|
||||
end |> value
|
||||
end
|
||||
|
||||
# Built-in implemenations
|
||||
|
||||
node(::typeof(*), args...) = mx.dot(args...)
|
||||
node(::typeof(+), args...) = mx.broadcast_plus(args...)
|
|
@ -0,0 +1,21 @@
|
|||
type MXModel
|
||||
model::Any
|
||||
params::Dict{Symbol,Any}
|
||||
exec::mx.Executor
|
||||
end
|
||||
|
||||
mxdims(dims::NTuple) =
|
||||
length(dims) == 1 ? (1, dims...) : reverse(dims)
|
||||
|
||||
function mxargs(args)
|
||||
map(args) do kv
|
||||
arg, value = kv
|
||||
arg => mx.zeros(mxdims(size(value)))
|
||||
end
|
||||
end
|
||||
|
||||
function mxnet(model::Model, input)
|
||||
vars = Dict{Symbol,Any}(:input => mx.zeros(mxdims(input)))
|
||||
node = graph(vars, model, mx.Variable(:input))
|
||||
MXModel(model, vars, mx.bind(node, args = mxargs(vars), grad_req = mx.GRAD_NOP))
|
||||
end
|
|
@ -0,0 +1,12 @@
|
|||
module MX
|
||||
|
||||
using MXNet, Flow, ..Flux
|
||||
|
||||
include("graph.jl")
|
||||
include("model.jl")
|
||||
|
||||
# d = Dense(20, 10)
|
||||
|
||||
# model = mxnet(d, (1,20))
|
||||
|
||||
end
|
|
@ -1,3 +1,5 @@
|
|||
export Dense
|
||||
|
||||
@model type Dense
|
||||
W
|
||||
b
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
export onehot, onecold
|
||||
export AArray, onehot, onecold
|
||||
|
||||
const AArray = AbstractArray
|
||||
|
||||
onehot(label, labels) = [i == label for i in labels]
|
||||
onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
|
||||
|
|
Loading…
Reference in New Issue