Register back! for logsigmoid and implement (logit)binarycrossentropy

boathit 2018-02-06 19:32:46 +08:00
parent f9be72f545
commit 6e65789828
5 changed files with 72 additions and 16 deletions

View File

@@ -13,7 +13,7 @@ export Chain, Dense, RNN, LSTM, GRU, Conv2D,
   param, params, mapleaves
 
 using NNlib
-export σ, sigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
+export σ, sigmoid, logσ, logsigmoid, relu, leakyrelu, elu, swish, softmax, logsoftmax,
   conv2d, maxpool2d, avgpool2d
 
 include("tracker/Tracker.jl")
@@ -36,4 +36,6 @@ include("layers/normalisation.jl")
 include("data/Data.jl")
 
+@require CuArrays include("cuda/cuda.jl")
+
 end # module
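For context on the new exports: logσ (also exported as logsigmoid) is NNlib's numerically stable log(σ(x)). A standalone sketch of why the stable form matters, using locally defined helpers rather than NNlib's actual implementation:

    # Naive log(σ(x)) overflows for large |x|; the usual stable rewrite is
    # min(0, x) - log1p(exp(-|x|)). Both helpers are local stand-ins.
    naive_logsig(x)  = log(1 / (1 + exp(-x)))
    stable_logsig(x) = min(zero(x), x) - log1p(exp(-abs(x)))

    @assert isapprox(naive_logsig(2.0), stable_logsig(2.0))
    @assert !isfinite(naive_logsig(-1000.0))   # exp(1000) overflows, giving -Inf
    @assert stable_logsig(-1000.0) == -1000.0  # stays finite and exact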

View File

@@ -1,4 +1,4 @@
-using NNlib: log_fast
+using NNlib: log_fast, logsoftmax, logσ
 
 # Cost functions
@@ -10,12 +10,37 @@ end
 @deprecate logloss(x, y) crossentropy(x, y)
 
-function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
-  logŷ = logŷ .- maximum(logŷ, 1)
-  ypred = logŷ .- log_fast.(sum(exp.(logŷ), 1))
-  -sum(y .* ypred) / size(y, 2)
+function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
+  return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
 end
+
+"""
+    binarycrossentropy(ŷ, y)
+
+Return `-y*log(ŷ) - (1-y)*log(1-ŷ)`.
+
+    julia> binarycrossentropy.(σ.([-1.1491, 0.8619, 0.3127]), [1, 1, 0.])
+    3-element Array{Float64,1}:
+     1.4244
+     0.352317
+     0.86167
+"""
+binarycrossentropy(ŷ, y) = -y*log_fast(ŷ) - (1 - y)*log_fast(1 - ŷ)
+
+"""
+    logitbinarycrossentropy(logŷ, y)
+
+`logitbinarycrossentropy(logŷ, y)` is mathematically equivalent to `binarycrossentropy(σ(logŷ), y)`
+but it is more numerically stable.
+
+    julia> logitbinarycrossentropy.([-1.1491, 0.8619, 0.3127], [1, 1, 0.])
+    3-element Array{Float64,1}:
+     1.4244
+     0.352317
+     0.86167
+"""
+logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
 
 """
     normalise(x::AbstractVecOrMat)
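The one-liner `logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)` follows from the identity `log(1 - σ(x)) = logσ(x) - x`. A self-contained numerical check of that equivalence; `sig`, `logsig`, `bce` and `logit_bce` below are local stand-ins defined only so the snippet runs on its own, not the Flux/NNlib functions:

    # Verify: -y*log(σ(x)) - (1-y)*log(1-σ(x))  ==  (1-y)*x - logσ(x)
    sig(x)    = 1 / (1 + exp(-x))
    logsig(x) = min(zero(x), x) - log1p(exp(-abs(x)))

    bce(ŷ, y)       = -y*log(ŷ) - (1 - y)*log(1 - ŷ)   # binary cross entropy
    logit_bce(x, y) = (1 - y)*x - logsig(x)            # logit form, as in the commit

    for (x, y) in zip([-1.1491, 0.8619, 0.3127], [1.0, 1.0, 0.0])
      @assert isapprox(bce(sig(x), y), logit_bce(x, y); atol = 1e-12)
    end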

View File

@@ -139,7 +139,7 @@ end
 # NNlib
 
 using NNlib
-import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv2d, pool
+import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, logσ, ∇logσ, conv2d, pool
 
 softmax(xs::TrackedArray) = TrackedArray(Call(softmax, xs))
@@ -149,6 +149,10 @@ logsoftmax(xs::TrackedArray) = TrackedArray(Call(logsoftmax, xs))
 back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs)))
 
+logσ(xs::TrackedArray) = TrackedArray(Call(logσ, xs))
+back(::typeof(logσ), Δ, xs) = @back(xs, ∇logσ(Δ, data(xs)))
+
 # TODO: can store kwargs efficiently in namedtuples
 _conv2d(x, w, stride, pad) = conv2d(x, w, stride = stride, padding = pad)
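The new back rule delegates the actual gradient to NNlib's ∇logσ; mathematically, d/dx log(σ(x)) = 1 - σ(x) = σ(-x). A standalone finite-difference sanity check of that derivative, again with local helpers rather than the tracked/NNlib functions:

    # d/dx log(σ(x)) = σ'(x)/σ(x) = 1 - σ(x); checked against a central difference.
    # sig/logsig are local stand-ins, not Flux/NNlib definitions.
    sig(x)    = 1 / (1 + exp(-x))
    logsig(x) = min(zero(x), x) - log1p(exp(-abs(x)))

    analytic(x)       = 1 - sig(x)
    centraldiff(f, x) = (f(x + 1e-6) - f(x - 1e-6)) / 2e-6

    for x in (-3.0, -0.5, 0.0, 2.0)
      @assert isapprox(analytic(x), centraldiff(logsig, x); atol = 1e-5)
    end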

View File

@@ -1,26 +1,49 @@
-using Flux: onehotbatch, mse, crossentropy
+using Base.Test
+using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
+            σ, binarycrossentropy, logitbinarycrossentropy
 
 @testset "losses" begin
   # First, regression-style y's
   y = [1, 1, 0, 0]
-  y_hat = [.9, .1, .1, .9]
+  ŷ = [.9, .1, .1, .9]
 
   @testset "mse" begin
-    @test mse(y_hat, y) ≈ (.1^2 + .9^2)/2
+    @test mse(ŷ, y) ≈ (.1^2 + .9^2)/2
   end
 
   # Now onehot y's
   y = onehotbatch([1, 1, 0, 0], 0:1)
-  y_hat = [.1 .9; .9 .1; .9 .1; .1 .9]'
-  y_logloss = 1.203972804325936
+  ŷ = [.1 .9; .9 .1; .9 .1; .1 .9]'
+  v = log(.1 / .9)
+  logŷ = [v 0.0; 0.0 v; 0.0 v; v 0.0]'
+  lossvalue = 1.203972804325936
 
   @testset "crossentropy" begin
-    @test crossentropy(y_hat, y) ≈ y_logloss
+    @test crossentropy(ŷ, y) ≈ lossvalue
+  end
+
+  @testset "logitcrossentropy" begin
+    @test logitcrossentropy(logŷ, y) ≈ lossvalue
   end
 
   @testset "weighted_crossentropy" begin
-    @test crossentropy(y_hat, y, weight = ones(2)) ≈ y_logloss
-    @test crossentropy(y_hat, y, weight = [.5, .5]) ≈ y_logloss/2
-    @test crossentropy(y_hat, y, weight = [2, .5]) ≈ 1.5049660054074199
+    @test crossentropy(ŷ, y, weight = ones(2)) ≈ lossvalue
+    @test crossentropy(ŷ, y, weight = [.5, .5]) ≈ lossvalue/2
+    @test crossentropy(ŷ, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+
+  @testset "weighted_logitcrossentropy" begin
+    @test logitcrossentropy(logŷ, y, weight = ones(2)) ≈ lossvalue
+    @test logitcrossentropy(logŷ, y, weight = [.5, .5]) ≈ lossvalue/2
+    @test logitcrossentropy(logŷ, y, weight = [2, .5]) ≈ 1.5049660054074199
+  end
+
+  logŷ, y = randn(3), rand(3)
+  @testset "binarycrossentropy" begin
+    @test binarycrossentropy.(σ.(logŷ), y) ≈ -y.*log.(σ.(logŷ)) - (1 - y).*log.(1 - σ.(logŷ))
+  end
+
+  @testset "logitbinarycrossentropy" begin
+    @test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y)
   end
 end
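The logitcrossentropy tests can reuse the same lossvalue as the crossentropy tests because, with v = log(.1/.9), the columns of logŷ softmax back to the columns of ŷ. A quick standalone check of that construction (mysoftmax is a local helper, not NNlib's softmax):

    # softmax([v, 0]) = [e^v, 1]/(e^v + 1) = [.1/.9, 1]/(.1/.9 + 1) = [.1, .9],
    # so logitcrossentropy(logŷ, y) should reproduce crossentropy(ŷ, y).
    mysoftmax(xs) = exp.(xs) ./ sum(exp.(xs))

    v = log(.1 / .9)
    @assert isapprox(mysoftmax([v, 0.0]), [0.1, 0.9]; atol = 1e-12)
    @assert isapprox(mysoftmax([0.0, v]), [0.9, 0.1]; atol = 1e-12)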

View File

@@ -9,6 +9,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), 5, (2,5), 2)
 @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2)
+@test gradtest((x, W, b) -> logσ.(W*x .+ b), 5, (2,5), 2)
+@test gradtest((x, W, b) -> logσ.(W*x .+ b), (5,3), (2,5), 2)
 
 @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10))
 @test gradtest((w, x) -> w*x', randn(5,5), randn(5,5))