diff --git a/src/Flux.jl b/src/Flux.jl
index 8ad4d1f9..88b2108e 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D,
   param, params, mapleaves
 
 using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
+export σ, sigmoid, logσ, logsigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
   conv2d, maxpool2d, avgpool2d
 
 include("tracker/Tracker.jl")
diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl
index 63c40cb8..34683fbf 100644
--- a/src/layers/stateless.jl
+++ b/src/layers/stateless.jl
@@ -1,4 +1,4 @@
-using NNlib: log_fast
+using NNlib: log_fast, logsoftmax, logσ
 
 # Cost functions
 
@@ -10,12 +10,37 @@ end
 
 @deprecate logloss(x, y) crossentropy(x, y)
 
-function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
-  logŷ = logŷ .- maximum(logŷ, 1)
-  ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
-  -sum(y .* ypred) / size(y, 2)
+function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
 end
 
+"""
+    binarycrossentropy(ŷ, y)
+
+Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
+
+    julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
+    3-element Array{Float64,1}:
+     1.4244
+     0.352317
+     0.86167
+"""
+binarycrossentropy(ŷ, y) = -y*log_fast(ŷ) - (1 - y)*log_fast(1 - ŷ)
+
+"""
+    logitbinarycrossentropy(logŷ, y)
+
+`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
+but it is more numerically stable.
+
+    julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
+    3-element Array{Float64,1}:
+     1.4244
+     0.352317
+     0.86167
+"""
+logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
+
 """
     normalise(x::AbstractVecOrMat)
 
diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl
index 23304eb1..ecfa7014 100644
--- a/test/layers/stateless.jl
+++ b/test/layers/stateless.jl
@@ -1,26 +1,49 @@
-using Flux: onehotbatch, mse, crossentropy
+using Base.Test
+using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
+  σ, binarycrossentropy, logitbinarycrossentropy
 
 @testset "losses" begin
   # First, regression-style y's
   y = [1, 1, 0, 0]
-  y_hat = [.9, .1, .1, .9]
+  ŷ = [.9, .1, .1, .9]
 
   @testset "mse" begin
-    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+    @test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
   end
 
   # Now onehot y's
   y = onehotbatch([1, 1, 0, 0], 0:1)
-  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
-  y_logloss = 1.203972804325936
+  ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
+  v = log(.1 / .9)
+  logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]'
+  lossvalue = 1.203972804325936
 
   @testset "crossentropy" begin
-    @test crossentropy(y_hat, y) ≈ y_logloss
+    @test crossentropy(ŷ, y) ≈ lossvalue
+  end
+
+  @testset "logitcrossentropy" begin
+    @test logitcrossentropy(logŷ, y) ≈ lossvalue
   end
 
   @testset "weighted_crossentropy" begin
-    @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
-    @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
-    @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
+    @test crossentropy(ŷ, y, weight = ones(2)) ≈ lossvalue
+    @test crossentropy(ŷ, y, weight = [.5, .5]) ≈ lossvalue/2
+    @test crossentropy(ŷ, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+
+  @testset "weighted_logitcrossentropy" begin
+    @test logitcrossentropy(logŷ, y, weight = ones(2)) ≈ lossvalue
+    @test logitcrossentropy(logŷ, y, weight = [.5, .5]) ≈ lossvalue/2
+    @test logitcrossentropy(logŷ, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+
+  logŷ, y = randn(3), rand(3)
+  @testset "binarycrossentropy" begin
+    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+  end
+
+  @testset "logitbinarycrossentropy" begin
+    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
   end
 end
diff --git a/test/tracker.jl b/test/tracker.jl
index 36b0fad1..12d018e1 100644
--- a/test/tracker.jl
+++ b/test/tracker.jl
@@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 
 @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
+@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
+@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
 
 @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
 @test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))