diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl
index 8e56b2c7..792987b5 100644
--- a/src/backend/mxnet/graph.jl
+++ b/src/backend/mxnet/graph.jl
@@ -20,10 +20,10 @@ node(x::mx.SymbolicNode) = x
 graph(::typeof(tuple), args...) = (args...,)
 graph(::typeof(+), args...) = mx.broadcast_plus(args...)
 graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching
-graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
-graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
-graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
-graph(::typeof(flatten), x) = mx.Flatten(data = x)
+graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
+graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
+graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
+graph(::typeof(flatten), x) = mx.Flatten(x)
 
 graph(::typeof(softmax), xs) =
   mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
@@ -36,20 +36,20 @@ graph(::Input, x) = x
 graph(ctx::Context, d::Affine, x) =
   !ctx[:feedforward] ? invoke(graph, (Context, Any, typeof(x)), ctx, d, x) :
    register(ctx,
-     mx.FullyConnected(data = x,
+     mx.FullyConnected(x,
                        num_hidden = size(d.W.x, 2),
                        weight = var(ctx, d.W, size(d.W)),
                        bias = var(ctx, d.b, size(d.b, 2))))
 
 # TODO: use actual params}
 graph(ctx::Context, c::Conv2D, x) =
-  mx.Convolution(data = x,
+  mx.Convolution(x,
                 kernel = size(c.filter, 1, 2),
                 num_filter = size(c.filter, 4),
                 stride = c.stride)
 
 graph(ctx::Context, p::MaxPool, x) =
-  mx.Pooling(data = x,
+  mx.Pooling(x,
             pool_type = :max,
             kernel = p.size,
             stride = p.stride)
diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl
index 9b371e9d..4ae1f93f 100644
--- a/src/backend/mxnet/model.jl
+++ b/src/backend/mxnet/model.jl
@@ -83,7 +83,7 @@ type SoftmaxOutput
   name::Symbol
 end
 
-graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
+graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(xs, name = s.name)
 
 function rewrite_softmax(model, name)
   model == softmax && return SoftmaxOutput(name)
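
For reference, a minimal sketch of the call-style change in this patch, assuming
MXNet.jl is available and that its generated symbolic operators accept the input
SymbolicNode as the first positional argument as well as via the `data` keyword;
the variable name `:data` and the bindings below are illustrative, not part of
the patch:

    using MXNet

    # A symbolic placeholder to feed into an operator.
    x = mx.Variable(:data)

    # Keyword form, as removed by this patch:
    a_kw = mx.Activation(data = x, act_type = :sigmoid)

    # Positional form, as introduced by this patch; under the stated
    # assumption it builds an equivalent symbolic node.
    a_pos = mx.Activation(x, act_type = :sigmoid)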