update for mxnet api
This commit is contained in:
parent
4a9517b23d
commit
616425554d
@ -20,10 +20,10 @@ node(x::mx.SymbolicNode) = x
|
||||
graph(::typeof(tuple), args...) = (args...,)
|
||||
graph(::typeof(+), args...) = mx.broadcast_plus(args...)
|
||||
graph(::typeof(*), x, W) = mx.dot(W, x) # Adjustments for batching
|
||||
graph(::typeof(σ), x) = mx.Activation(data = x, act_type = :sigmoid)
|
||||
graph(::typeof(relu), x) = mx.Activation(data = x, act_type = :relu)
|
||||
graph(::typeof(tanh), x) = mx.Activation(data = x, act_type = :tanh)
|
||||
graph(::typeof(flatten), x) = mx.Flatten(data = x)
|
||||
graph(::typeof(σ), x) = mx.Activation(x, act_type = :sigmoid)
|
||||
graph(::typeof(relu), x) = mx.Activation(x, act_type = :relu)
|
||||
graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
|
||||
graph(::typeof(flatten), x) = mx.Flatten(x)
|
||||
|
||||
graph(::typeof(softmax), xs) =
|
||||
mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
|
||||
@ -36,20 +36,20 @@ graph(::Input, x) = x
|
||||
graph(ctx::Context, d::Affine, x) =
|
||||
!ctx[:feedforward] ? invoke(graph, (Context, Any, typeof(x)), ctx, d, x) :
|
||||
register(ctx,
|
||||
mx.FullyConnected(data = x,
|
||||
mx.FullyConnected(x,
|
||||
num_hidden = size(d.W.x, 2),
|
||||
weight = var(ctx, d.W, size(d.W)),
|
||||
bias = var(ctx, d.b, size(d.b, 2))))
|
||||
|
||||
# TODO: use actual params}
|
||||
graph(ctx::Context, c::Conv2D, x) =
|
||||
mx.Convolution(data = x,
|
||||
mx.Convolution(x,
|
||||
kernel = size(c.filter, 1, 2),
|
||||
num_filter = size(c.filter, 4),
|
||||
stride = c.stride)
|
||||
|
||||
graph(ctx::Context, p::MaxPool, x) =
|
||||
mx.Pooling(data = x,
|
||||
mx.Pooling(x,
|
||||
pool_type = :max,
|
||||
kernel = p.size,
|
||||
stride = p.stride)
|
||||
|
@ -83,7 +83,7 @@ type SoftmaxOutput
|
||||
name::Symbol
|
||||
end
|
||||
|
||||
graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(data = xs, name = s.name)
|
||||
graph(s::SoftmaxOutput, xs) = mx.SoftmaxOutput(xs, name = s.name)
|
||||
|
||||
function rewrite_softmax(model, name)
|
||||
model == softmax && return SoftmaxOutput(name)
|
||||
|
Loading…
Reference in New Issue
Block a user