fix promotion by avoiding integer division in mse and crossentropy
oops add tests
This commit is contained in:
parent
9781f063aa
commit
c74aa67c5d
@ -2,16 +2,16 @@ using NNlib: logsoftmax, logσ
|
|||||||
|
|
||||||
# Cost functions
|
# Cost functions
|
||||||
|
|
||||||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
mse(ŷ, y) = sum((ŷ .- y).^2) * 1 // length(y)
|
||||||
|
|
||||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||||
-sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
-sum(y .* log.(ŷ) .* weight) * 1 // size(y, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
@deprecate logloss(x, y) crossentropy(x, y)
|
@deprecate logloss(x, y) crossentropy(x, y)
|
||||||
|
|
||||||
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||||
return -sum(y .* logsoftmax(logŷ) .* weight) / size(y, 2)
|
return -sum(y .* logsoftmax(logŷ) .* weight) * 1 // size(y, 2)
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
@ -49,4 +49,16 @@ const ϵ = 1e-7
|
|||||||
@testset "logitbinarycrossentropy" begin
|
@testset "logitbinarycrossentropy" begin
|
||||||
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
|
@test logitbinarycrossentropy.(logŷ, y) ≈ binarycrossentropy.(σ.(logŷ), y; ϵ=0)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@testset "no spurious promotions" begin
|
||||||
|
for T in (Float16, Float32, Float64)
|
||||||
|
y = rand(T, 2)
|
||||||
|
ŷ = rand(T, 2)
|
||||||
|
for f in (mse, crossentropy, logitcrossentropy)
|
||||||
|
fwd, back = Flux.Tracker.forward(mse, ŷ, y)
|
||||||
|
@test typeof(fwd) == Flux.Tracker.TrackedReal{T}
|
||||||
|
@test eltype(back(one(T))[1]) == Flux.Tracker.TrackedReal{T}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user