Fix logitbinarycrossentropy on CuArrays

This commit is contained in:
matsueushi 2019-11-22 05:23:24 +00:00
parent 5839e166f6
commit a0314ce682
2 changed files with 6 additions and 2 deletions

View File

@ -53,6 +53,9 @@ but it is more numerically stable.
"""
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
# Re-definition to fix interaction with CuArrays.
CuArrays.@cufunc logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
"""
normalise(x::AbstractArray; dims=1)

View File

@ -31,9 +31,10 @@ cx = gpu(x)
@test Flux.crossentropy(x,x, weight=1.0) Flux.crossentropy(cx,cx, weight=1.0)
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
x = σ.([-1.1491, 0.8619, 0.3127])
x = [-1.1491, 0.8619, 0.3127]
y = [1, 1, 0.]
@test Flux.binarycrossentropy.(x,y) Flux.binarycrossentropy.(cu(x),cu(y))
@test Flux.binarycrossentropy.(σ.(x),y) Flux.binarycrossentropy.(cu(σ.(x)),cu(y))
@test Flux.logitbinarycrossentropy.(x,y) Flux.logitbinarycrossentropy.(cu(x),cu(y))
xs = rand(5, 5)
ys = Flux.onehotbatch(1:5,1:5)