Fix logitbinarycrossentropy on CuArrays
This commit is contained in:
parent
5839e166f6
commit
a0314ce682
|
@ -53,6 +53,9 @@ but it is more numerically stable.
|
|||
"""
|
||||
logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
||||
|
||||
# Re-definition to fix interaction with CuArrays.
|
||||
CuArrays.@cufunc logitbinarycrossentropy(logŷ, y) = (1 - y)*logŷ - logσ(logŷ)
|
||||
|
||||
"""
|
||||
normalise(x::AbstractArray; dims=1)
|
||||
|
||||
|
|
|
@ -31,9 +31,10 @@ cx = gpu(x)
|
|||
@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0)
|
||||
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) ≈ Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
|
||||
|
||||
x = σ.([-1.1491, 0.8619, 0.3127])
|
||||
x = [-1.1491, 0.8619, 0.3127]
|
||||
y = [1, 1, 0.]
|
||||
@test Flux.binarycrossentropy.(x,y) ≈ Flux.binarycrossentropy.(cu(x),cu(y))
|
||||
@test Flux.binarycrossentropy.(σ.(x),y) ≈ Flux.binarycrossentropy.(cu(σ.(x)),cu(y))
|
||||
@test Flux.logitbinarycrossentropy.(x,y) ≈ Flux.logitbinarycrossentropy.(cu(x),cu(y))
|
||||
|
||||
xs = rand(5, 5)
|
||||
ys = Flux.onehotbatch(1:5,1:5)
|
||||
|
|
Loading…
Reference in New Issue