From f5a8900ffb50b7b5fede3d3c085740c41ab595ff Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 12 May 2020 17:29:35 +0100 Subject: [PATCH] xlogy broadcast adjoint --- src/layers/stateless.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/layers/stateless.jl b/src/layers/stateless.jl index 860bae9b..50f64b52 100644 --- a/src/layers/stateless.jl +++ b/src/layers/stateless.jl @@ -288,3 +288,8 @@ CuArrays.@cufunc function xlogy(x, y) result = x * log(y) ifelse(iszero(x), zero(result), result) end + +@adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric) + res = xlogy.(x, y) + res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y)) +end