some attempts to get mxnet working

This commit is contained in:
Mike J Innes 2016-09-26 21:44:53 +01:00
parent 330a5e785a
commit df38a89d9a
2 changed files with 10 additions and 10 deletions

View File

@ -7,7 +7,7 @@ graph(vars, model, args...) = node(model, args...)
graph(vars, x::mx.SymbolicNode) = x
# TODO: detect parameters used more than once
function graph{T<:AArray}(vars::Associative, p::Flux.Param{T})
function graph{T<:AArray}(vars, p::Flux.Param{T})
value = p.x
id = gensym()
vars[id] = value
@ -68,9 +68,8 @@ graph(vars, p::MaxPool, x) =
stride = p.stride)
# TODO: fix the initialisation issue
graph(vars::Void, d::Dense, x) =
graph(vars, d::Dense, x) =
mx.FullyConnected(data = x,
num_hidden = size(d.W.x, 1),
# weight = graph(vars, d.W),
# bias = graph(vars, d.b)
)
weight = graph(vars, d.W),
bias = graph(vars, d.b))

View File

@ -10,8 +10,7 @@ end
Base.show(io::IO, m::MXModel) =
print(io, "MXModel($(m.model))")
mxdims(dims::NTuple) =
length(dims) == 1 ? (1, dims...) : reverse(dims)
mxdims(dims::NTuple) = reverse(dims)
mxdims(n::Integer) = mxdims((n,))
@ -29,7 +28,7 @@ ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
function mxargs(args)
map(args) do kv
arg, value = kv
arg => mx.zeros(mxdims(size(value)))
arg => tond(value)
end
end
@ -91,6 +90,8 @@ end
function mx.FeedForward(model::Model; input = :data, label = :softmax, context = mx.cpu())
model = rewrite_softmax(model, label)
node, _ = mxgraph(model, input, vars = false)
return mx.FeedForward(node, context = context)
node, vars = mxgraph(model, input)
ff = mx.FeedForward(node, context = context)
ff.arg_params = mxargs(vars)
return ff
end