fix gpu cross entropy
This commit is contained in:
parent
d12fb98f2a
commit
8f73dc6e14
|
@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
|
|||
mse(ŷ, y) = sum((ŷ .- y).^2)/length(y)
|
||||
|
||||
function crossentropy(ŷ::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
|
||||
return @fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
@fix -sum(y .* log.(ŷ) .* weight) / size(y, 2)
|
||||
end
|
||||
|
||||
@deprecate logloss(x, y) crossentropy(x, y)
|
||||
|
|
|
@ -21,6 +21,10 @@ cm = gpu(m)
|
|||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
|
||||
x = [1,2,3]
|
||||
cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
|
||||
# Fails in Pkg.test ffs
|
||||
# c = gpu(Conv((2,2),3=>4))
|
||||
# l = c(gpu(rand(10,10,3,2)))
|
||||
|
|
Loading…
Reference in New Issue