better tests

This commit is contained in:
Dhairya Gandhi 2019-02-04 18:43:25 +05:30
parent 6654ebfc90
commit 2f916f9763
1 changed files with 4 additions and 3 deletions

View File

@ -39,9 +39,10 @@ Flux.back!(sum(l))
end
@testset "onecold gpu" begin
x = rand(Float32, 10, 3) |> gpu;
y = Flux.onehotbatch(1:3, 1:10) |> gpu;
@test_nowarn Flux.onecold(x) .== Flux.onecold(y)
x = zeros(Float32, 10, 3) |> gpu;
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
res = Flux.onecold(x) .== Flux.onecold(y)
@test res isa CuArray
end
if CuArrays.libcudnn != nothing