From feb72d400ad75cfdb64374d9ee95d4b3a0d49ab3 Mon Sep 17 00:00:00 2001 From: cossio Date: Thu, 7 May 2020 12:44:32 +0200 Subject: [PATCH] NaN --- src/layers/stateless.jl | 8 ++++---- test/layers/stateless.jl | 21 +++++++++++++-------- 2 files changed, 17 insertions(+), 12 deletions(-) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 57c7398d..bf020688 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -269,11 +269,11 @@ Return `x * log(x)` for `x ≥ 0`, handling `x = 0` by taking the downward limit """ function xlogx(x) result = x * log(x) - ifelse(x > zero(x), result, zero(result)) + ifelse(iszero(x), zero(result), result) end CuArrays.@cufunc function xlogx(x) result = x * log(x) - ifelse(x > zero(x), result, zero(result)) + ifelse(iszero(x), zero(result), result) end """ @@ -282,9 +282,9 @@ Return `x * log(y)` for `y > 0` with correct limit at `x = 0`. """ function xlogy(x, y) result = x * log(y) - ifelse(x > zero(x), result, zero(result)) + ifelse(iszero(x), zero(result), result) end CuArrays.@cufunc function xlogy(x, y) result = x * log(y) - ifelse(x > zero(x), result, zero(result)) + ifelse(iszero(x), zero(result), result) end diff --git a/test/layers/stateless.jl b/test/layers/stateless.jl index bc83d27b..a61e912a 100644 --- a/test/layers/stateless.jl +++ b/test/layers/stateless.jl @@ -6,14 +6,19 @@ using Flux: onehotbatch, mse, crossentropy, logitcrossentropy, const ϵ = 1e-7 @testset "xlogx & xlogy" begin - @test iszero(xlogx(0)) - @test xlogx(2) ≈ 2.0 * log(2.0) - @inferred xlogx(2) - @inferred xlogx(0) - @test iszero(xlogy(0, 1)) - @test xlogy(2, 3) ≈ 2.0 * log(3.0) - @inferred xlogy(2, 3) - @inferred xlogy(0, 1) + @test iszero(xlogx(0)) + @test isnan(xlogx(NaN)) + @test xlogx(2) ≈ 2.0 * log(2.0) + @inferred xlogx(2) + @inferred xlogx(0) + + @test iszero(xlogy(0, 1)) + @test isnan(xlogy(NaN, 1)) + @test isnan(xlogy(1, NaN)) + @test isnan(xlogy(NaN, NaN)) + @test xlogy(2, 3) ≈ 2.0 * log(3.0) + @inferred xlogy(2, 3) + @inferred xlogy(0, 1) end @testset "losses" begin