added onecold broadcast test
This commit is contained in:
parent
30aa814c4d
commit
6654ebfc90
|
@ -38,6 +38,12 @@ Flux.back!(sum(l))
|
|||
|
||||
end
|
||||
|
||||
@testset "onecold gpu" begin
|
||||
x = rand(Float32, 10, 3) |> gpu;
|
||||
y = Flux.onehotbatch(1:3, 1:10) |> gpu;
|
||||
@test_nowarn Flux.onecold(x) .== Flux.onecold(y)
|
||||
end
|
||||
|
||||
if CuArrays.libcudnn != nothing
|
||||
@info "Testing Flux/CUDNN"
|
||||
include("cudnn.jl")
|
||||
|
|
Loading…
Reference in New Issue