diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index c8e1b793..13e165f0 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -264,14 +264,19 @@ function flatten(x::AbstractArray) end """ - xlogx(x::Real) + xlogx(x) Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit. """ -xlogx(x::Real) = x > zero(x) ? x * log(x) : zero(log(x)) +function xlogx(x) + result = x * log(x) + ifelse(x > zero(x), result, zero(result)) +end """ - xlogy(x::Real, y::Real) + xlogy(x, y) Return `x * log(y)` for `y > 0` with correct limit at `x = 0`. """ -xlogy(x::T, y::T) where {T<:Real} = x > zero(T) ? x * log(y) : zero(log(x)) -xlogy(x::Real, y::Real) = xlogy(promote(x, y)...) +function xlogy(x, y) + result = x * log(y) + ifelse(x > zero(x), result, zero(result)) +end diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index 4e66c39f..bc83d27b 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -8,9 +8,12 @@ const ϵ = 1e-7 @testset "xlogx & xlogy" begin @test iszero(xlogx(0)) @test xlogx(2) ≈ 2.0 * log(2.0) - + @inferred xlogx(2) + @inferred xlogx(0) @test iszero(xlogy(0, 1)) @test xlogy(2, 3) ≈ 2.0 * log(3.0) + @inferred xlogy(2, 3) + @inferred xlogy(0, 1) end @testset "losses" begin