diff --git a/src/layers/activation.jl b/src/layers/activation.jl
index f3324aa5..26ab89f6 100644
--- a/src/layers/activation.jl
+++ b/src/layers/activation.jl
@@ -8,11 +8,9 @@ back!(::typeof(σ), Δ, x) = Δ .* σ(x).*(1.-σ(x))
 relu(x) = max(0, x)
 back!(::typeof(relu), Δ, x) = Δ .* (x .> 0)
 
-# TODO: correct behaviour with batches
-softmax(xs) = exp.(xs) ./ sum(exp.(xs))
+softmax(xs) = exp.(xs) ./ sum(exp.(xs), 2)
 
-# TODO: correct behaviour with batches
-flatten(xs) = reshape(xs, length(xs))
+flatten(xs) = reshape(xs, size(xs, 1), :)
 
 infer(::typeof(softmax), x) = x
 infer(::typeof(tanh), x) = x
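
A quick sanity check of the new batch semantics (a minimal standalone sketch, not part of the diff: it assumes the batch convention implied by `sum(..., 2)`, i.e. dimension 1 indexes samples and dimension 2 indexes classes, and it uses the positional-dims form of `sum` exactly as it appears in the diff):

```julia
# Batch-aware definitions as introduced by this diff, copied here for
# illustration; the `back!`/`infer` methods are omitted.
softmax(xs) = exp.(xs) ./ sum(exp.(xs), 2)
flatten(xs) = reshape(xs, size(xs, 1), :)

xs = [1.0 2.0 3.0;
      4.0 5.0 6.0]      # batch of two samples (rows), three classes (columns)
sum(softmax(xs), 2)     # each row now sums to 1, i.e. ≈ ones(2, 1)

ys = rand(2, 3, 4)      # batch of 2 samples, 3×4 features each
size(flatten(ys))       # (2, 12): sample dimension kept, the rest collapsed
```

Note the shared design choice: both definitions now treat dimension 1 as the batch axis, so `softmax` normalizes each sample independently and `flatten` no longer collapses a whole batch into a single vector.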