assert no scalar indexing for onecold
This commit is contained in:
parent
35cd9761a8
commit
1ada9afe81
|
@ -39,6 +39,7 @@ Flux.back!(sum(l))
|
|||
end
|
||||
|
||||
@testset "onecold gpu" begin
|
||||
CuArrays.allowscalar(false)
|
||||
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
|
||||
@test Flux.onecold(y) isa CuArray
|
||||
@test y[3,:] isa CuArray
|
||||
|
|
Loading…
Reference in New Issue