remove old shape hacks

This commit is contained in:
Mike J Innes 2017-01-28 23:07:02 +05:30
parent 3fa8a55e41
commit bb70f401be
2 changed files with 5 additions and 9 deletions

View File

@ -11,7 +11,7 @@ node(x::mx.SymbolicNode) = x
graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(*), args...) = mx.dot(reverse(args)...)
graph(::typeof(*), args...) = mx.dot(args...)
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)

View File

@ -7,18 +7,14 @@ type MXModel <: Model
exec::mx.Executor
end
mxdims(dims::NTuple) = reverse(dims)
mxdims(n::Integer) = mxdims((n,))
function tond!(nd::mx.NDArray, xs::AArray)
mx.copy_ignore_shape!(nd, xs')
mx.copy_ignore_shape!(nd, xs)
nd
end
tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs)
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
fromnd(xs::mx.NDArray) = copy(xs)'
fromnd(xs::mx.NDArray) = copy(xs)
ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
@ -45,7 +41,7 @@ end
function mxnet(model::Model, input)
params, stacks, node = tograph(model, mx.Variable(:input))
args = merge(mxargs(params), Dict(:input => mx.zeros(mxdims(input))))
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
grads = mxgrads(args)
model = MXModel(model, params, grads,
mx.bind(node, args = args,