diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 8d675735..63c40cb8 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -4,16 +4,10 @@ using NNlib: log_fast mse(ŷ, y) = sum((ŷ .- y).^2)/length(y) -function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) - return -sum(y .* log_fast.(ŷ)) / size(y, 2) +function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1) + return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2) end -function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat) - return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2) -end - - - @deprecate logloss(x, y) crossentropy(x, y) function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat) diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index b7a42841..23304eb1 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -1,4 +1,4 @@ -using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy +using Flux: onehotbatch, mse, crossentropy @testset "losses" begin # First, regression-style y's @@ -19,8 +19,8 @@ using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy end @testset "weighted_crossentropy" begin - @test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss - @test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2 - @test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199 + @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss + @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2 + @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199 end end