diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 860bae9b..50f64b52 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -288,3 +288,8 @@ CuArrays.@cufunc function xlogy(x, y) result = x * log(y) ifelse(iszero(x), zero(result), result) end + +@adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric) + res = xlogy.(x, y) + res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y)) +end