remove old shape hacks

This commit is contained in:
Mike J Innes 2017-01-28 23:07:02 +05:30
parent 3fa8a55e41
commit bb70f401be
2 changed files with 5 additions and 9 deletions

View File

@ -11,7 +11,7 @@ node(x::mx.SymbolicNode) = x
graph(::typeof(tuple), args...) = (args...,) graph(::typeof(tuple), args...) = (args...,)
graph(s::Split, t::Tuple) = t[s.n] graph(s::Split, t::Tuple) = t[s.n]
graph(::typeof(*), args...) = mx.dot(reverse(args)...) graph(::typeof(*), args...) = mx.dot(args...)
graph(::typeof(+), args...) = mx.broadcast_plus(args...) graph(::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid) graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu) graph(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)

View File

@ -7,18 +7,14 @@ type MXModel <: Model
exec::mx.Executor exec::mx.Executor
end end
mxdims(dims::NTuple) = reverse(dims)
mxdims(n::Integer) = mxdims((n,))
function tond!(nd::mx.NDArray, xs::AArray) function tond!(nd::mx.NDArray, xs::AArray)
mx.copy_ignore_shape!(nd, xs') mx.copy_ignore_shape!(nd, xs)
nd nd
end end
tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs) tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
fromnd(xs::mx.NDArray) = copy(xs)' fromnd(xs::mx.NDArray) = copy(xs)
ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs))) ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
@ -45,7 +41,7 @@ end
function mxnet(model::Model, input) function mxnet(model::Model, input)
params, stacks, node = tograph(model, mx.Variable(:input)) params, stacks, node = tograph(model, mx.Variable(:input))
args = merge(mxargs(params), Dict(:input => mx.zeros(mxdims(input)))) args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
grads = mxgrads(args) grads = mxgrads(args)
model = MXModel(model, params, grads, model = MXModel(model, params, grads,
mx.bind(node, args = args, mx.bind(node, args = args,