xlogy broadcast adjoint
This commit is contained in:
parent
bd43201f37
commit
f5a8900ffb
|
@ -288,3 +288,8 @@ CuArrays.@cufunc function xlogy(x, y)
|
|||
result = x * log(y)
|
||||
ifelse(iszero(x), zero(result), result)
|
||||
end
|
||||
|
||||
@adjoint function broadcasted(::typeof(xlogy), x::Zygote.Numeric, y::Zygote.Numeric)
|
||||
res = xlogy.(x, y)
|
||||
res, Δ -> (nothing, Zygote.unbroadcast(x, xlogy.(Δ, y)), Zygote.unbroadcast(y, Δ .* x ./ y))
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue