Merge pull request #538 from KristofferC/kc/promote
fix promotion by avoiding integer division in mse and crossentropy
This commit is contained in:
commit
67d9016319
|
@ -2,16 +2,16 @@ using NNlib: logsoftmax, logσ
|
|||
|
||||
# Cost functions
|
||||
|
||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
|
||||
|
||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
-sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
|
||||
end
|
||||
|
||||
@deprecate logloss(x, y) crossentropy(x, y)
|
||||
|
||||
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
|
||||
return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
|
||||
end
|
||||
|
||||
"""
|
||||
|
|
|
@ -49,4 +49,16 @@ const ϵ = 1e-7
|
|||
@testset "logitbinarycrossentropy" begin
|
||||
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
|
||||
end
|
||||
|
||||
@testset "no spurious promotions" begin
|
||||
for T in (Float16, Float32, Float64)
|
||||
y = rand(T, 2)
|
||||
ŷ = rand(T, 2)
|
||||
for f in (mse, crossentropy, logitcrossentropy)
|
||||
fwd, back = Flux.Tracker.forward(mse, ŷ, y)
|
||||
@test typeof(fwd) == Flux.Tracker.TrackedReal{T}
|
||||
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue