Merge pull request #926 from janEbert/bc-cuda-fix
Fix binarycrossentropy on CuArrays
This commit is contained in:
commit
9d6f6fdaa3
@ -1,3 +1,4 @@
|
||||
using CuArrays
|
||||
using NNlib: logsoftmax, logσ
|
||||
|
||||
# Cost functions
|
||||
@ -35,6 +36,9 @@ Return `-y*log(ŷ + ϵ) - (1-y)*log(1-ŷ + ϵ)`. The ϵ term provides numerica
|
||||
"""
|
||||
binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||
|
||||
# Re-definition to fix interaction with CuArrays.
|
||||
CuArrays.@cufunc binarycrossentropy(ŷ, y; ϵ=eps(ŷ)) = -y*log(ŷ + ϵ) - (1 - y)*log(1 - ŷ + ϵ)
|
||||
|
||||
"""
|
||||
logitbinarycrossentropy(logŷ, y)
|
||||
|
||||
|
@ -31,6 +31,10 @@ cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0)
|
||||
@test Flux.crossentropy(x,x, weight=[1.0;2.0;3.0]) ≈ Flux.crossentropy(cx,cx, weight=cu([1.0;2.0;3.0]))
|
||||
|
||||
x = σ.([-1.1491, 0.8619, 0.3127])
|
||||
y = [1, 1, 0.]
|
||||
@test Flux.binarycrossentropy.(x,y) ≈ Flux.binarycrossentropy.(cu(x),cu(y))
|
||||
|
||||
xs = rand(5, 5)
|
||||
ys = Flux.onehotbatch(1:5,1:5)
|
||||
@test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
|
||||
|
Loading…
Reference in New Issue
Block a user