softmax on batches
This commit is contained in:
parent
d1ce09211d
commit
d27d59b071
@ -26,7 +26,7 @@ graph(::typeof(tanh), x) = mx.Activation(x, act_type = :tanh)
|
||||
graph(::typeof(flatten), x) = mx.Flatten(x)
|
||||
|
||||
graph(::typeof(softmax), xs) =
|
||||
mx.broadcast_div(exp(xs), mx.Reshape(mx.sum(exp(xs)), shape = (1,1)))
|
||||
mx.broadcast_div(exp(xs), mx.sum(exp(xs), axis = 1, keepdims=true))
|
||||
|
||||
graph(::typeof(cat), dim::Integer, a...) = mx.Concat(a..., dim = dim)
|
||||
graph(::typeof(vcat), a...) = graph(cat, 1, a...)
|
||||
|
Loading…
Reference in New Issue
Block a user