Merge pull request #165 from boathit/master
Register back! for logsigmoid and implement (logit)binarycrossentropy
Commit d12120207d
src/Flux.jl
@@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D,
   param, params, mapleaves
 
 using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
+export σ, sigmoid, logσ, logsigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
   conv2d, maxpool2d, avgpool2d
 
 include("tracker/Tracker.jl")
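The newly exported logσ / logsigmoid come from NNlib; Flux only re-exports them. A quick REPL sketch (illustrative, not part of the commit; it assumes logsigmoid aliases logσ the same way sigmoid aliases σ):

    julia> using Flux

    julia> logσ(0.8619) ≈ log(σ(0.8619))  # agrees with the naive composition on moderate inputs
    true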
src/layers/stateless.jl
@@ -1,4 +1,4 @@
-using NNlib: log_fast
+using NNlib: log_fast, logsoftmax, logσ
 
 # Cost functions
 
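Importing logsoftmax here lets the reworked logitcrossentropy in the next hunk shed its hand-rolled shift-by-maximum. A sketch of the equivalence (illustrative, not part of the commit; written in the Julia 0.6-era reduction syntax the old code uses):

    using NNlib: logsoftmax

    logŷ = randn(5, 3)
    shifted = logŷ .- maximum(logŷ, 1)               # old first step: subtract each column's maximum
    manual = shifted .- log.(sum(exp.(shifted), 1))  # old log-softmax, spelled out by hand
    @assert manual ≈ logsoftmax(logŷ)                # the single call that now replaces both lines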
@@ -10,12 +10,37 @@ end
 
 @deprecate logloss(x, y) crossentropy(x, y)
 
-function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
-  logŷ = logŷ .- maximum(logŷ, 1)
-  ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
-  -sum(y .* ypred) / size(y, 2)
+function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
 end
 
+"""
+    binarycrossentropy(ŷ, y)
+
+Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
+
+    julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
+    3-element Array{Float64,1}:
+     1.4244
+     0.352317
+     0.86167
+"""
+binarycrossentropy(ŷ, y) = -y*log_fast(ŷ) - (1 - y)*log_fast(1 - ŷ)
+
+"""
+    logitbinarycrossentropy(logŷ, y)
+
+`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
+but it is more numerically stable.
+
+    julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
+    3-element Array{Float64,1}:
+     1.4244
+     0.352317
+     0.86167
+"""
+logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
+
 """
     normalise(x::AbstractVecOrMat)
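The docstring's stability claim follows from a two-line identity: since 1 - σ(x) = σ(x)*exp(-x), we have log(1 - σ(x)) = logσ(x) - x, and therefore

    -y*logσ(x) - (1 - y)*log(1 - σ(x)) = (1 - y)*x - logσ(x)

which is exactly the one-liner above: σ(x) is never materialised, so nothing saturates to 0 or 1 before the log. A sketch of what goes wrong otherwise (illustrative, not part of the commit):

    using Flux: σ, binarycrossentropy, logitbinarycrossentropy

    # Moderate logits: the two forms agree.
    @assert logitbinarycrossentropy(0.8619, 1.0) ≈ binarycrossentropy(σ(0.8619), 1.0)

    # Extreme logit with y = 1: σ underflows to 0.0, so the naive form blows
    # up while the logit form stays finite (assuming logσ is evaluated stably).
    binarycrossentropy(σ(-800.0), 1.0)    # Inf: log of an underflowed zero
    logitbinarycrossentropy(-800.0, 1.0)  # ≈ 800.0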
test/layers/stateless.jl
@@ -1,26 +1,49 @@
-using Flux: onehotbatch, mse, crossentropy
+using Base.Test
+using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
+  σ, binarycrossentropy, logitbinarycrossentropy
 
 @testset "losses" begin
   # First, regression-style y's
   y = [1, 1, 0, 0]
-  y_hat = [.9, .1, .1, .9]
+  ŷ = [.9, .1, .1, .9]
 
   @testset "mse" begin
-    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+    @test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
   end
 
   # Now onehot y's
   y = onehotbatch([1, 1, 0, 0], 0:1)
-  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
-  y_logloss = 1.203972804325936
+  ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
+  v = log(.1 / .9)
+  logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]'
+  lossvalue = 1.203972804325936
 
   @testset "crossentropy" begin
-    @test crossentropy(y_hat, y) ≈ y_logloss
+    @test crossentropy(ŷ, y) ≈ lossvalue
+  end
+
+  @testset "logitcrossentropy" begin
+    @test logitcrossentropy(logŷ, y) ≈ lossvalue
   end
 
   @testset "weighted_crossentropy" begin
-    @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
-    @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
-    @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
+    @test crossentropy(ŷ, y, weight = ones(2)) ≈ lossvalue
+    @test crossentropy(ŷ, y, weight = [.5, .5]) ≈ lossvalue/2
+    @test crossentropy(ŷ, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+
+  @testset "weighted_logitcrossentropy" begin
+    @test logitcrossentropy(logŷ, y, weight = ones(2)) ≈ lossvalue
+    @test logitcrossentropy(logŷ, y, weight = [.5, .5]) ≈ lossvalue/2
+    @test logitcrossentropy(logŷ, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+
+  logŷ, y = randn(3), rand(3)
+  @testset "binarycrossentropy" begin
+    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+  end
+
+  @testset "logitbinarycrossentropy" begin
+    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
   end
 end
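The logits in the new tests are chosen so that softmax(logŷ) reproduces ŷ column for column: each column holds (v, 0) or (0, v) with v = log(.1/.9), and softmax of (v, 0) is (.1, .9). That is why crossentropy on ŷ and logitcrossentropy on logŷ are compared against the same lossvalue, and why the weighted variants scale identically. A standalone check (illustrative, not part of the commit):

    using Flux: softmax

    v = log(.1 / .9)
    logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]'
    ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
    @assert softmax(logŷ) ≈ ŷ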
test/tracker.jl
@@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 
 @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
+@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
+@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
 
 @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
 @test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
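These gradtests exercise the back! rule this PR registers for logsigmoid. The rule itself is short: d/dx logσ(x) = σ'(x)/σ(x) = 1 - σ(x). A finite-difference spot check (illustrative, not part of the commit):

    using Flux: σ, logσ

    x = 0.3127
    fd = (logσ(x + 1e-6) - logσ(x - 1e-6)) / 2e-6
    @assert isapprox(fd, 1 - σ(x), atol = 1e-6)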