1175: xlogy broadcast adjoint r=MikeInnes a=MikeInnes

This is helpful for performance, since it avoids having to differentiate `xlogy` itself inside of a map.

Co-authored-by: Mike J Innes <mike.j.innes@gmail.com>
This commit is contained in:
bors[bot] 2020-05-12 17:10:58 +00:00 committed by GitHub
commit de39d1095b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 5 additions and 0 deletions

View File

@ -288,3 +288,8 @@ CuArrays.@cufunc function xlogy(x, y)
result = x * log(y)
ifelse(iszero(x), zero(result), result)
end
@adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric)
res = xlogy.(x, y)
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
end