Merge pull request #538 from KristofferC/kc/promote

fix promotion by avoiding integer division in mse and crossentropy
This commit is contained in:
Mike J Innes 2019-01-15 13:20:46 +00:00 committed by GitHub
commit 67d9016319
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 3 deletions

View File

@ -2,16 +2,16 @@ using NNlib: logsoftmax, logσ
# Cost functions
mse(, y) = sum(( .- y).^2)/length(y)
mse(, y) = sum(( .- y).^2) * 1 // length(y)
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
-sum(y .* log.() .* weight) / size(y, 2)
-sum(y .* log.() .* weight) * 1 // size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
end
"""

View File

@ -49,4 +49,16 @@ const ϵ = 1e-7
@testset "logitbinarycrossentropy" begin
@test logitbinarycrossentropy.(logŷ, y) binarycrossentropy.(σ.(logŷ), y; ϵ=0)
end
@testset "no spurious promotions" begin
for T in (Float16, Float32, Float64)
y = rand(T, 2)
ŷ = rand(T, 2)
for f in (mse, crossentropy, logitcrossentropy)
fwd, back = Flux.Tracker.forward(mse, , y)
@test typeof(fwd) == Flux.Tracker.TrackedReal{T}
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
end
end
end
end