# Flux.jl/test/cuda/cuda.jl

# GPU test setup: load Flux (with its Tracker AD), CuArrays and Test,
# and bring `gpu` into scope for moving values onto the device.
using Flux, Flux.Tracker, CuArrays, Test
using Flux: gpu

@info "Testing GPU Support"

@testset "CuArrays" begin
  # Disallow scalar (element-by-element) indexing of GPU arrays so that any
  # accidental slow fallback raises an error instead of silently passing.
  CuArrays.allowscalar(false)

  # Tracked arrays keep their tracking wrapper when moved to the GPU.
  x = param(randn(5, 5))
  cx = gpu(x)
  @test cx isa TrackedArray && cx.data isa CuArray
  @test Flux.onecold(param(gpu([1.,2.,3.]))) == 3

  # One-hot matrices move their storage to the GPU, and broadcasting over
  # them yields a CuArray.
  x = Flux.onehotbatch([1, 2, 3], 1:3)
  cx = gpu(x)
  @test cx isa Flux.OneHotMatrix && cx.data isa CuArray
  @test (cx .+ 1) isa CuArray

  # Every parameter of a model is moved, and the forward pass stays on the GPU.
  m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
  cm = gpu(m)
  @test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
  @test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}

  # GPU results must agree with the CPU reference. Compare with `≈` rather
  # than `==` because the two paths may differ by floating-point rounding.
  x = [1,2,3]
  cx = gpu(x)
  @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)

  # Mixed dense + one-hot broadcast on the GPU matches the CPU result.
  xs = param(rand(5,5))
  ys = Flux.onehotbatch(1:5,1:5)
  @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)

  # Smoke test: a convolution's forward and backward pass run on the GPU
  # without throwing.
  c = gpu(Conv((2,2),3=>4))
  l = c(gpu(rand(10,10,3,2)))
  Flux.back!(sum(l))
end

@testset "onecold gpu" begin
  y = gpu(Flux.onehotbatch(ones(3), 1:10))
  # Both decoding and row slicing must return device arrays.
  @test Flux.onecold(y) isa CuArray
  @test y[3,:] isa CuArray
end
# Run the CUDNN-dependent test files only when `CuArrays.libcudnn` is set;
# it is `nothing` when CUDNN is unavailable.
if CuArrays.libcudnn !== nothing
  @info "Testing Flux/CUDNN"
  include("cudnn.jl")
  # Defining CI_DISABLE_CURNN_TEST (to any value) skips the cuRNN tests.
  if !haskey(ENV, "CI_DISABLE_CURNN_TEST")
    include("curnn.jl")
  end
end