update for mxnet api

This commit is contained in:
Mike J Innes 2017-02-23 16:58:10 +00:00
parent 4a9517b23d
commit 616425554d
2 changed files with 8 additions and 8 deletions

View File

@ -20,10 +20,10 @@ node(x::mx.SymbolicNode) = x
graph(::typeof(tuple), args...) = (args...,)
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
graph(::typeof(flatten), x) = mx.Flatten(data = x)
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
graph(::typeof(flatten), x) = mx.Flatten(x)
graph(::typeof(softmax), xs) =
mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
@ -36,20 +36,20 @@ graph(::Input, x) = x
graph(ctx::Context, d::Affine, x) =
!ctx[:feedforward] ? invoke(graph, (Context, Any, typeof(x)), ctx, d, x) :
register(ctx,
mx.FullyConnected(data = x,
mx.FullyConnected(x,
num_hidden = size(d.W.x, 2),
weight = var(ctx, d.W, size(d.W)),
bias = var(ctx, d.b, size(d.b, 2))))
# TODO: use actual params}
graph(ctx::Context, c::Conv2D, x) =
mx.Convolution(data = x,
mx.Convolution(x,
kernel = size(c.filter, 1, 2),
num_filter = size(c.filter, 4),
stride = c.stride)
graph(ctx::Context, p::MaxPool, x) =
mx.Pooling(data = x,
mx.Pooling(x,
pool_type = :max,
kernel = p.size,
stride = p.stride)

View File

@ -83,7 +83,7 @@ type SoftmaxOutput
name::Symbol
end
graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(xs, name = s.name)
function rewrite_softmax(model, name)
model == softmax && return SoftmaxOutput(name)