Merge pull request #165 from boathit/master

Register back! for logsigmoid and implement (logit)binarycrossentropy
This commit is contained in:
Mike J Innes 2018-02-13 14:56:53 +00:00 committed by GitHub
commit d12120207d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 65 additions and 15 deletions

View File

@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D,
param, params, mapleaves
using NNlib
export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
export σ, sigmoid, logσ, logsigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
conv2d, maxpool2d, avgpool2d
include("tracker/Tracker.jl")

View File

@ -1,4 +1,4 @@
using NNlib: log_fast
using NNlib: log_fast, logsoftmax, logσ
# Cost functions
@ -10,12 +10,37 @@ end
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
logŷ = logŷ .- maximum(logŷ, 1)
ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
-sum(y .* ypred) / size(y, 2)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
end
"""
binarycrossentropy(, y)
Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
3-element Array{Float64,1}:
1.4244
0.352317
0.86167
"""
binarycrossentropy(, y) = -y*log_fast() - (1 - y)*log_fast(1 - )
"""
logitbinarycrossentropy(logŷ, y)
`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
but it is more numerically stable.
julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
3-element Array{Float64,1}:
1.4244
0.352317
0.86167
"""
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
"""
normalise(x::AbstractVecOrMat)

View File

@ -1,26 +1,49 @@
using Flux: onehotbatch, mse, crossentropy
using Base.Test
using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
σ, binarycrossentropy, logitbinarycrossentropy
@testset "losses" begin
# First, regression-style y's
y = [1, 1, 0, 0]
y_hat = [.9, .1, .1, .9]
ŷ = [.9, .1, .1, .9]
@testset "mse" begin
@test mse(y_hat, y) (.1^2 + .9^2)/2
@test mse(ŷ, y) (.1^2 + .9^2)/2
end
# Now onehot y's
y = onehotbatch([1, 1, 0, 0], 0:1)
y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
y_logloss = 1.203972804325936
ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
v = log(.1 / .9)
logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]'
lossvalue = 1.203972804325936
@testset "crossentropy" begin
@test crossentropy(y_hat, y) y_logloss
@test crossentropy(ŷ, y) lossvalue
end
@testset "logitcrossentropy" begin
@test logitcrossentropy(logŷ, y) lossvalue
end
@testset "weighted_crossentropy" begin
@test crossentropy(y_hat, y, weight = ones(2)) y_logloss
@test crossentropy(y_hat, y, weight = [.5, .5]) y_logloss/2
@test crossentropy(y_hat, y, weight = [2, .5]) 1.5049660054074199
@test crossentropy(ŷ, y, weight = ones(2)) lossvalue
@test crossentropy(ŷ, y, weight = [.5, .5]) lossvalue/2
@test crossentropy(ŷ, y, weight = [2, .5]) 1.5049660054074199
end
@testset "weighted_logitcrossentropy" begin
@test logitcrossentropy(logŷ, y, weight = ones(2)) lossvalue
@test logitcrossentropy(logŷ, y, weight = [.5, .5]) lossvalue/2
@test logitcrossentropy(logŷ, y, weight = [2, .5]) 1.5049660054074199
end
logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
end
@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y)
end
end

View File

@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))