assert no scalar indexing for onecold
This commit is contained in:
parent
35cd9761a8
commit
1ada9afe81
@ -39,6 +39,7 @@ Flux.back!(sum(l))
|
|||||||
end
|
end
|
||||||
|
|
||||||
@testset "onecold gpu" begin
|
@testset "onecold gpu" begin
|
||||||
|
CuArrays.allowscalar(false)
|
||||||
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
|
y = Flux.onehotbatch(ones(3), 1:10) |> gpu;
|
||||||
@test Flux.onecold(y) isa CuArray
|
@test Flux.onecold(y) isa CuArray
|
||||||
@test y[3,:] isa CuArray
|
@test y[3,:] isa CuArray
|
||||||
|
Loading…
Reference in New Issue
Block a user