diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl
index ff1cbc39..b8ce3c7d 100644
--- a/src/layers/stateless.jl
+++ b/src/layers/stateless.jl
@@ -4,10 +4,35 @@ using NNlib: logsoftmax, logσ
 
 mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
 
-function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-  -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
+# `_crossentropy` dispatches on the type of `weight`; every method scales the summed
+# loss by `1 // size(y, 2)`. With `weight === nothing` no weighting is applied.
+function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Nothing)
+  return -sum(y .* log.(ŷ)) * 1 // size(y, 2)
 end
 
+# A scalar weight multiplies the already-summed loss instead of being broadcast.
+function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::Number)
+  return -sum(y .* log.(ŷ)) .* weight * 1 // size(y, 2)
+end
+
+# An array weight is broadcast against `y .* log.(ŷ)` before summing. Accepting any
+# `AbstractArray` (not only vectors) keeps matrix-shaped weights working, as they
+# did when `weight` was simply broadcast by the previous implementation.
+function _crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, weight::AbstractArray)
+  return -sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
+end
+
+"""
+    crossentropy(ŷ, y; weight = nothing)
+
+Return the cross-entropy `-sum(y .* log.(ŷ))`, optionally weighted, scaled by
+`1 // size(y, 2)`.
+
+`weight` may be `nothing` (no weighting), a number that scales the summed loss, or
+an array that is broadcast against `y .* log.(ŷ)` before summing.
+"""
+crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight=nothing) = _crossentropy(ŷ, y, weight)
+
 function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
   return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
 end
diff --git a/test/cuda/cuda.jl b/test/cuda/cuda.jl
index 59bc7f50..9bafe44a 100644
--- a/test/cuda/cuda.jl
+++ b/test/cuda/cuda.jl
@@ -28,6 +28,8 @@ cm = gpu(m)
 x = [1,2,3]
 cx = gpu(x)
 @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
+@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0)
+@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) ≈ Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
 
 xs = rand(5, 5)
 ys = Flux.onehotbatch(1:5,1:5)