From bb70f401be94489d4fb875a10e1e32af04575a90 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sat, 28 Jan 2017 23:07:02 +0530 Subject: [PATCH] remove old shape hacks --- src/backend/mxnet/graph.jl | 2 +- src/backend/mxnet/model.jl | 12 ++++-------- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index 183806bc..6a7db27b 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -11,7 +11,7 @@ node(x::mx.SymbolicNode) = x graph(::typeof(tuple), args...) = (args...,) graph(s::Split, t::Tuple) = t[s.n] -graph(::typeof(*), args...) = mx.dot(reverse(args)...) +graph(::typeof(*), args...) = mx.dot(args...) graph(::typeof(+), args...) = mx.broadcast_plus(args...) graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid) graph(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu) diff --git a/src/backend/mxnet/model.jl b/src/backend/mxnet/model.jl index fe27d8d9..da3b62f6 100644 --- a/src/backend/mxnet/model.jl +++ b/src/backend/mxnet/model.jl @@ -7,18 +7,14 @@ type MXModel <: Model exec::mx.Executor end -mxdims(dims::NTuple) = reverse(dims) - -mxdims(n::Integer) = mxdims((n,)) - function tond!(nd::mx.NDArray, xs::AArray) - mx.copy_ignore_shape!(nd, xs') + mx.copy_ignore_shape!(nd, xs) nd end -tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs) +tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs) -fromnd(xs::mx.NDArray) = copy(xs)' +fromnd(xs::mx.NDArray) = copy(xs) ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs))) @@ -45,7 +41,7 @@ end function mxnet(model::Model, input) params, stacks, node = tograph(model, mx.Variable(:input)) - args = merge(mxargs(params), Dict(:input => mx.zeros(mxdims(input)))) + args = merge(mxargs(params), Dict(:input => mx.zeros(input))) grads = mxgrads(args) model = MXModel(model, params, grads, mx.bind(node, args = args,