diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 13e165f0..57c7398d 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -271,6 +271,10 @@ function xlogx(x) result = x * log(x) ifelse(x > zero(x), result, zero(result)) end +CuArrays.@cufunc function xlogx(x) + result = x * log(x) + ifelse(x > zero(x), result, zero(result)) +end """ xlogy(x, y) @@ -280,3 +284,7 @@ function xlogy(x, y) result = x * log(y) ifelse(x > zero(x), result, zero(result)) end +CuArrays.@cufunc function xlogy(x, y) + result = x * log(y) + ifelse(x > zero(x), result, zero(result)) +end