Register back! for logsigmoid and implement (logit)binarycrossentropy

boathit 2018-02-06 19:32:46 +08:00
parent f9be72f545
commit 6e65789828
5 changed files with 72 additions and 16 deletions
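The heart of this commit is the identity behind `logitbinarycrossentropy`: since `1 - σ(x) = σ(-x)` and `logσ(-x) = logσ(x) - x`, the loss `-y*log(σ(x)) - (1-y)*log(1-σ(x))` collapses to `(1 - y)*x - logσ(x)`. A standalone sketch of why that form is the numerically stable one (plain Julia, hypothetical names `bce`/`logitbce`):

σ(x) = 1 / (1 + exp(-x))
logσ(x) = x < 0 ? x - log1p(exp(x)) : -log1p(exp(-x))  # overflow-safe log-sigmoid

bce(ŷ, y)      = -y*log(ŷ) - (1 - y)*log(1 - ŷ)  # naive form
logitbce(x, y) = (1 - y)*x - logσ(x)             # the form this commit adds

bce(σ(1.0), 0.0) ≈ logitbce(1.0, 0.0)  # true: identical in the safe range
bce(σ(100.0), 0.0)                     # Inf: σ(100.0) rounds to 1.0, so log(1 - 1.0) = -Inf
logitbce(100.0, 0.0)                   # ≈ 100.0, stays finite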

View File

@@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D,
param, params, mapleaves
using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
+export σ, sigmoid, logσ, logsigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
conv2d, maxpool2d, avgpool2d
include("tracker/Tracker.jl")
@@ -36,4 +36,6 @@ include("layers/normalisation.jl")
include("data/Data.jl")
@require CuArrays include("cuda/cuda.jl")
end # module

View File

@@ -1,4 +1,4 @@
-using NNlib: log_fast
+using NNlib: log_fast, logsoftmax, logσ
# Cost functions
@@ -10,12 +10,37 @@ end
@deprecate logloss(x, y) crossentropy(x, y)
-function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
-  logŷ = logŷ .- maximum(logŷ, 1)
-  ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
-  -sum(y .* ypred) / size(y, 2)
+function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
end
"""
binarycrossentropy(, y)
Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
3-element Array{Float64,1}:
1.4244
0.352317
0.86167
"""
binarycrossentropy(, y) = -y*log_fast() - (1 - y)*log_fast(1 - )
"""
logitbinarycrossentropy(logŷ, y)
`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
but it is more numerically stable.
julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
3-element Array{Float64,1}:
1.4244
0.352317
0.86167
"""
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
"""
normalise(x::AbstractVecOrMat)

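A quick way to sanity-check the rewritten `logitcrossentropy`: `logsoftmax(logŷ)` equals `log.(softmax(logŷ))` up to floating point, so on arbitrary logits the new function should agree with `crossentropy` of the softmax. A sketch against the API defined above (the data is invented):

using Flux: crossentropy, logitcrossentropy, onehotbatch
using NNlib: softmax

logŷ = [0.5 2.0 -1.0; 1.5 -0.5 0.0]  # 2 classes × 3 samples, arbitrary logits
y = onehotbatch([0, 1, 0], 0:1)
logitcrossentropy(logŷ, y) ≈ crossentropy(softmax(logŷ), y)  # expected: true
logitcrossentropy(logŷ, y, weight = [2, .5])  # per-class weights, exercised in the tests below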
View File

@@ -139,7 +139,7 @@ end
# NNlib
using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, logσ, ∇logσ, conv2d, pool
softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
@@ -149,6 +149,10 @@ logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
+logσ(xs::TrackedArray) = TrackedArray(Call(logσ, xs))
+
+back(::typeof(logσ), Δ, xs) = @back(xs, ∇logσ(Δ, data(xs)))
# TODO: can store kwargs efficiently in namedtuples
_conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)

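The two lines registered above follow Tracker's usual pattern: the forward call wraps its result in a `TrackedArray`, and `back` routes the incoming gradient Δ through NNlib's `∇logσ`. Since `d/dx logσ(x) = 1 - σ(x) = σ(-x)`, the pullback has to compute the following (a sketch of the math, not NNlib's actual source):

σ(x) = 1 / (1 + exp(-x))
∇logσ_sketch(Δ, xs) = Δ .* σ.(-xs)  # hypothetical stand-in for NNlib's ∇logσ

∇logσ_sketch([1.0, 1.0, 1.0], [-10.0, 0.0, 10.0])
# ≈ [0.999955, 0.5, 4.54e-5]: near-unit gradient where σ saturates low, vanishing where it saturates high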
View File

@@ -1,26 +1,49 @@
-using Flux: onehotbatch, mse, crossentropy
+using Base.Test
+using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
+  σ, binarycrossentropy, logitbinarycrossentropy
@testset "losses" begin
# First, regression-style y's
y = [1, 1, 0, 0]
-y_hat = [.9, .1, .1, .9]
+ŷ = [.9, .1, .1, .9]
@testset "mse" begin
-@test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+@test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
end
# Now onehot y's
y = onehotbatch([1, 1, 0, 0], 0:1)
-y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
-y_logloss = 1.203972804325936
+ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
+v = log(.1 / .9)
+logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]'
+lossvalue = 1.203972804325936
@testset "crossentropy" begin
-@test crossentropy(y_hat, y) ≈ y_logloss
+@test crossentropy(ŷ, y) ≈ lossvalue
end
@testset "logitcrossentropy" begin
@test logitcrossentropy(logŷ, y) lossvalue
end
@testset "weighted_crossentropy" begin
-@test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
-@test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
-@test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
+@test crossentropy(ŷ, y, weight = ones(2)) ≈ lossvalue
+@test crossentropy(ŷ, y, weight = [.5, .5]) ≈ lossvalue/2
+@test crossentropy(ŷ, y, weight = [2, .5]) ≈ 1.5049660054074199
end
@testset "weighted_logitcrossentropy" begin
@test logitcrossentropy(logŷ, y, weight = ones(2)) lossvalue
@test logitcrossentropy(logŷ, y, weight = [.5, .5]) lossvalue/2
@test logitcrossentropy(logŷ, y, weight = [2, .5]) 1.5049660054074199
end
logŷ, y = randn(3), rand(3)
@testset "binarycrossentropy" begin
@test binarycrossentropy.(σ.(logŷ), y) -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
end
@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y)
end
end
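For the record, the constant in the weighted tests is reproducible by hand. The true-class probabilities are .9 and .1 for the two class-1 columns (weight .5) and .9 and .1 for the two class-0 columns (weight 2), averaged over the 4 columns:

-(.5*(log(.9) + log(.1)) + 2*(log(.9) + log(.1))) / 4  # ≈ 1.5049660054074199
-(log(.9) + log(.1)) / 2                               # ≈ 1.203972804325936, the unweighted lossvalue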

View File

@@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
@test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
+@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
+@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
@test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))
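`gradtest` exercises the new `logσ` rule through Tracker; the same property can be checked by hand with a central finite difference against the analytic gradient `1 - σ(x)` (standalone sketch, no Flux needed):

σ(x) = 1 / (1 + exp(-x))
logσ(x) = -log1p(exp(-x))

x, ε = 0.3, 1e-6
numeric  = (logσ(x + ε) - logσ(x - ε)) / (2ε)  # central difference
analytic = 1 - σ(x)                            # d/dx logσ(x)
isapprox(numeric, analytic; atol = 1e-8)       # expected: true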