Add a class-weighted cross-entropy loss and tests for the stateless losses.

Summary of the hunks below:
  * src/layers/stateless.jl: rewrite `crossentropy` from a short-form definition to a
    `function ... end` block, and add `weighted_crossentropy(ŷ, y, w)`, which multiplies
    `y .* log_fast.(ŷ)` element-wise by `w` before summing and dividing by `size(y, 2)`.
  * test/layers/stateless.jl (new file): tests for `mse`, `crossentropy` and
    `weighted_crossentropy`.
  * test/runtests.jl: include the new test file.

NOTE(review): this patch was recovered from a copy whose line breaks and leading
whitespace had been collapsed. Indentation inside the hunks assumes a 2-space indent;
confirm against the original commit before applying, since removed and context lines
must match the target files exactly.

diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl
index edbdec58..8d675735 100644
--- a/src/layers/stateless.jl
+++ b/src/layers/stateless.jl
@@ -4,8 +4,15 @@ using NNlib: log_fast
 
 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
 
-crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-  -sum(y .* log_fast.(ŷ)) / size(y, 2)
+function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat)
+  return -sum(y .* log_fast.(ŷ)) / size(y, 2)
+end
+
+function weighted_crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat, w::AbstractVecOrMat)
+  return -sum(y .* log_fast.(ŷ) .* w) / size(y, 2)
+end
+
+
 
 @deprecate logloss(x, y) crossentropy(x, y)
 
diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl
new file mode 100644
index 00000000..b7a42841
--- /dev/null
+++ b/test/layers/stateless.jl
@@ -0,0 +1,26 @@
+using Flux: onehotbatch, mse, crossentropy, weighted_crossentropy
+
+@testset "losses" begin
+  # First, regression-style y's
+  y = [1, 1, 0, 0]
+  y_hat = [.9, .1, .1, .9]
+
+  @testset "mse" begin
+    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+  end
+
+  # Now onehot y's
+  y = onehotbatch([1, 1, 0, 0], 0:1)
+  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
+  y_logloss = 1.203972804325936
+
+  @testset "crossentropy" begin
+    @test crossentropy(y_hat, y) ≈ y_logloss
+  end
+
+  @testset "weighted_crossentropy" begin
+    @test weighted_crossentropy(y_hat, y, ones(2)) ≈ y_logloss
+    @test weighted_crossentropy(y_hat, y, [.5, .5]) ≈ y_logloss/2
+    @test weighted_crossentropy(y_hat, y, [2, .5]) ≈ 1.5049660054074199
+  end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index efd1a462..5c6ba549 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,5 +5,6 @@ using Flux, Base.Test
 include("utils.jl")
 include("tracker.jl")
 include("layers/normalisation.jl")
+include("layers/stateless.jl")
 
 end