better tests
This commit is contained in:
parent
6654ebfc90
commit
2f916f9763
@ -39,9 +39,10 @@ Flux.back!(sum(l))
|
|||||||
end
|
end
|
||||||
|
|
||||||
@testset "onecold gpu" begin
|
@testset "onecold gpu" begin
|
||||||
x = rand(Float32, 10, 3) |> gpu;
|
x = zeros(Float32, 10, 3) |> gpu;
|
||||||
y = Flux.onehotbatch(1:3, 1:10) |> gpu;
|
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
|
||||||
@test_nowarn Flux.onecold(x) .== Flux.onecold(y)
|
res = Flux.onecold(x) .== Flux.onecold(y)
|
||||||
|
@test res isa CuArray
|
||||||
end
|
end
|
||||||
|
|
||||||
if CuArrays.libcudnn != nothing
|
if CuArrays.libcudnn != nothing
|
||||||
|
Loading…
Reference in New Issue
Block a user