Don't invoke GPU crossentropy with integers.
Broadcasting log on integers does not work.
This commit is contained in:
parent
e66a7f130f
commit
e2c2ec5575
|
@ -25,7 +25,7 @@ cm = gpu(m)
|
|||
@test all(p isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa CuArray{Float32,2}
|
||||
|
||||
x = [1,2,3]
|
||||
x = [1.,2.,3.]
|
||||
cx = gpu(x)
|
||||
@test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
|
||||
@test Flux.crossentropy(x,x, weight=1.0) ≈ Flux.crossentropy(cx,cx, weight=1.0)
|
||||
|
|
Loading…
Reference in New Issue