test onecold-of-tracked-gpu-vector

see #556
This commit is contained in:
Mike Innes 2019-01-24 10:40:52 +00:00
parent 62d780c77f
commit 0142d89943
1 changed files with 2 additions and 0 deletions

View File

@ -11,6 +11,8 @@ x = param(randn(5, 5))
cx = gpu(x)
@test cx isa TrackedArray && cx.data isa CuArray
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
x = Flux.onehotbatch([1, 2, 3], 1:3)
cx = gpu(x)
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray