remove old shape hacks
This commit is contained in:
parent
3fa8a55e41
commit
bb70f401be
@ -11,7 +11,7 @@ node(x::mx.SymbolicNode) = x
|
||||
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(s::Split, t::Tuple) = t[s.n]
|
||||
graph(::typeof(*), args...) = mx.dot(reverse(args)...)
|
||||
graph(::typeof(*), args...) = mx.dot(args...)
|
||||
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||
graph(::typeof(relu), x) = mx.Activation(data = x, act_type=:relu)
|
||||
|
@ -7,18 +7,14 @@ type MXModel <: Model
|
||||
exec::mx.Executor
|
||||
end
|
||||
|
||||
mxdims(dims::NTuple) = reverse(dims)
|
||||
|
||||
mxdims(n::Integer) = mxdims((n,))
|
||||
|
||||
function tond!(nd::mx.NDArray, xs::AArray)
|
||||
mx.copy_ignore_shape!(nd, xs')
|
||||
mx.copy_ignore_shape!(nd, xs)
|
||||
nd
|
||||
end
|
||||
|
||||
tond(xs::AArray) = tond!(mx.zeros(mxdims(size(xs))), xs)
|
||||
tond(xs::AArray) = copy!(mx.zeros(size(xs)), xs)
|
||||
|
||||
fromnd(xs::mx.NDArray) = copy(xs)'
|
||||
fromnd(xs::mx.NDArray) = copy(xs)
|
||||
|
||||
ndzero!(xs::mx.NDArray) = copy!(xs, mx.zeros(size(xs)))
|
||||
|
||||
@ -45,7 +41,7 @@ end
|
||||
|
||||
function mxnet(model::Model, input)
|
||||
params, stacks, node = tograph(model, mx.Variable(:input))
|
||||
args = merge(mxargs(params), Dict(:input => mx.zeros(mxdims(input))))
|
||||
args = merge(mxargs(params), Dict(:input => mx.zeros(input)))
|
||||
grads = mxgrads(args)
|
||||
model = MXModel(model, params, grads,
|
||||
mx.bind(node, args = args,
|
||||
|
Loading…
Reference in New Issue
Block a user