This commit is contained in:
cossio 2020-05-07 12:44:32 +02:00
parent 86d6555269
commit feb72d400a
2 changed files with 17 additions and 12 deletions

View File

@ -269,11 +269,11 @@ Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit
""" """
function xlogx(x) function xlogx(x)
result = x * log(x) result = x * log(x)
ifelse(x > zero(x), result, zero(result)) ifelse(iszero(x), zero(result), result)
end end
CuArrays.@cufunc function xlogx(x) CuArrays.@cufunc function xlogx(x)
result = x * log(x) result = x * log(x)
ifelse(x > zero(x), result, zero(result)) ifelse(iszero(x), zero(result), result)
end end
""" """
@ -282,9 +282,9 @@ Return `x * log(y)` for `y > 0` with correct limit at `x = 0`.
""" """
function xlogy(x, y) function xlogy(x, y)
result = x * log(y) result = x * log(y)
ifelse(x > zero(x), result, zero(result)) ifelse(iszero(x), zero(result), result)
end end
CuArrays.@cufunc function xlogy(x, y) CuArrays.@cufunc function xlogy(x, y)
result = x * log(y) result = x * log(y)
ifelse(x > zero(x), result, zero(result)) ifelse(iszero(x), zero(result), result)
end end

View File

@ -6,14 +6,19 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy,
const ϵ = 1e-7 const ϵ = 1e-7
@testset "xlogx & xlogy" begin @testset "xlogx & xlogy" begin
@test iszero(xlogx(0)) @test iszero(xlogx(0))
@test xlogx(2) 2.0 * log(2.0) @test isnan(xlogx(NaN))
@inferred xlogx(2) @test xlogx(2) 2.0 * log(2.0)
@inferred xlogx(0) @inferred xlogx(2)
@test iszero(xlogy(0, 1)) @inferred xlogx(0)
@test xlogy(2, 3) 2.0 * log(3.0)
@inferred xlogy(2, 3) @test iszero(xlogy(0, 1))
@inferred xlogy(0, 1) @test isnan(xlogy(NaN, 1))
@test isnan(xlogy(1, NaN))
@test isnan(xlogy(NaN, NaN))
@test xlogy(2, 3) 2.0 * log(3.0)
@inferred xlogy(2, 3)
@inferred xlogy(0, 1)
end end
@testset "losses" begin @testset "losses" begin