From d27d59b07144cd5dd2861bc482de677c08c6fa86 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 16 Mar 2017 11:51:57 +0000 Subject: [PATCH] softmax on batches --- src/backend/mxnet/graph.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/backend/mxnet/graph.jl b/src/backend/mxnet/graph.jl index d0395ca1..17d04f09 100644 --- a/src/backend/mxnet/graph.jl +++ b/src/backend/mxnet/graph.jl @@ -26,7 +26,7 @@ graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh) graph(::typeof(flatten), x) = mx.Flatten(x) graph(::typeof(softmax), xs) = - mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1))) + mx.broadcast_div(exp(xs), mx.sum(exp(xs), axis = 1, keepdims=true)) graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim) graph(::typeof(vcat), a...) = graph(cat, 1, a...)