Modified tests in cuda.jl
This commit is contained in:
parent
1ff4e3188e
commit
25f74d1b4a
|
@ -1,5 +1,6 @@
|
|||
using Flux, CuArrays, Test
|
||||
using Flux: gpu
|
||||
using Zygote
|
||||
|
||||
@info "Testing GPU Support"
|
||||
|
||||
|
@ -9,20 +10,20 @@ CuArrays.allowscalar(false)
|
|||
|
||||
x = param(randn(5, 5))
|
||||
cx = gpu(x)
|
||||
@test cx isa TrackedArray && cx.data isa CuArray
|
||||
@test cx isa CuArray
|
||||
|
||||
@test Flux.onecold(param(gpu([1.,2.,3.]))) == 3
|
||||
|
||||
x = Flux.onehotbatch([1, 2, 3], 1:3)
|
||||
cx = gpu(x)
|
||||
@test cx isa Flux.OneHotMatrix && cx.data isa CuArray
|
||||
@test cx isa Flux.OneHotMatrix && cx isa CuArray
|
||||
@test (cx .+ 1) isa CuArray
|
||||
|
||||
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
|
||||
cm = gpu(m)
|
||||
|
||||
@test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||
@test all(p isa CuArray for p in params(cm))
|
||||
@test cm(gpu(rand(10, 10))) isa CuArray{Float32,2}
|
||||
|
||||
x = [1,2,3]
|
||||
cx = gpu(x)
|
||||
|
@ -34,11 +35,13 @@ ys = Flux.onehotbatch(1:5,1:5)
|
|||
|
||||
c = gpu(Conv((2,2),3=>4))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
Flux.back!(sum(l))
|
||||
fwd, back = Zygote.forward(sum, l)
|
||||
back(one(Float64))
|
||||
|
||||
c = gpu(CrossCor((2,2),3=>4))
|
||||
l = c(gpu(rand(10,10,3,2)))
|
||||
Flux.back!(sum(l))
|
||||
fwd, back = Zygote.forward(sum, l)
|
||||
back(one(Float64))
|
||||
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue