fix gpu cross entropy

This commit is contained in:
Mike J Innes 2018-04-17 17:20:51 +01:00
parent d12fb98f2a
commit 8f73dc6e14
2 changed files with 5 additions and 1 deletions

View File

@ -5,7 +5,7 @@ using NNlib: logsoftmax, logσ
mse(, y) = sum(( .- y).^2)/length(y)
function crossentropy(::AbstractVecOrMat, y::AbstractVecOrMat; weight = 1)
return @fix -sum(y .* log.() .* weight) / size(y, 2)
@fix -sum(y .* log.() .* weight) / size(y, 2)
end
@deprecate logloss(x, y) crossentropy(x, y)

View File

@ -21,6 +21,10 @@ cm = gpu(m)
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
x = [1,2,3]
cx = gpu(x)
@test Flux.crossentropy(x,x) Flux.crossentropy(cx,cx)
# Fails in Pkg.test ffs
# c = gpu(Conv((2,2),3=>4))
# l = c(gpu(rand(10,10,3,2)))