diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl
index edbdec58..63c40cb8 100644
--- a/src/layers/stateless.jl
+++ b/src/layers/stateless.jl
@@ -4,8 +4,9 @@ using NNlib: log_fast
 
 mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
 
-crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat) =
-  -sum(y .* log_fast.(ŷ)) / size(y, 2)
+function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* log_fast.(ŷ) .* weight) / size(y, 2)
+end
 
 @deprecate logloss(x, y) crossentropy(x, y)
 
diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl
new file mode 100644
index 00000000..23304eb1
--- /dev/null
+++ b/test/layers/stateless.jl
@@ -0,0 +1,26 @@
+using Flux: onehotbatch, mse, crossentropy
+
+@testset "losses" begin
+  # First, regression-style y's
+  y = [1, 1, 0, 0]
+  y_hat = [.9, .1, .1, .9]
+
+  @testset "mse" begin
+    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+  end
+
+  # Now onehot y's
+  y = onehotbatch([1, 1, 0, 0], 0:1)
+  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
+  y_logloss = 1.203972804325936
+
+  @testset "crossentropy" begin
+    @test crossentropy(y_hat, y) ≈ y_logloss
+  end
+
+  @testset "weighted_crossentropy" begin
+    @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
+    @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
+    @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index bdd1f2d0..38ddb85f 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -5,6 +5,7 @@ using Flux, Base.Test
 include("utils.jl")
 include("tracker.jl")
 include("layers/normalisation.jl")
+include("layers/stateless.jl")
 include("optimise.jl")
 
 end