Flux.jl/test/cuda/cuda.jl

45 lines
1.0 KiB
Julia
Raw Normal View History

2018-07-18 13:39:20 +00:00
using Flux, Flux.Tracker, CuArrays, Test
2018-02-28 22:51:08 +00:00
using Flux: gpu
2018-01-16 17:58:14 +00:00
2018-09-11 11:28:05 +00:00
# @info "Testing GPU Support"
#
# @testset "CuArrays" begin
#
# CuArrays.allowscalar(false)
#
# x = param(randn(5, 5))
# cx = gpu(x)
# @test cx isa TrackedArray && cx.data isa CuArray
#
# x = Flux.onehotbatch([1, 2, 3], 1:3)
# cx = gpu(x)
# @test cx isa Flux.OneHotMatrix && cx.data isa CuArray
# @test (cx .+ 1) isa CuArray
#
# m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax)
# cm = gpu(m)
#
# @test all(p isa TrackedArray && p.data isa CuArray for p in params(cm))
# @test cm(gpu(rand(10, 10))) isa TrackedArray{Float32,2,CuArray{Float32,2}}
#
# x = [1,2,3]
# cx = gpu(x)
# @test Flux.crossentropy(x,x) ≈ Flux.crossentropy(cx,cx)
#
# xs = param(rand(5,5))
# ys = Flux.onehotbatch(1:5,1:5)
# @test collect(cu(xs) .+ cu(ys)) ≈ collect(xs .+ ys)
#
# c = gpu(Conv((2,2),3=>4))
# l = c(gpu(rand(10,10,3,2)))
# Flux.back!(sum(l))
#
# end
2018-01-30 13:12:33 +00:00
2018-06-22 12:49:18 +00:00
# Run the CUDNN-backed test files (BatchNorm in cudnn.jl, RNN in curnn.jl) only
# when `CuArrays.cudnn_available()` returns true. Otherwise emit a warning so the
# skipped coverage shows up in the test log instead of disappearing silently.
if CuArrays.cudnn_available()
  @info "Testing Flux/CUDNN BatchNorm"
  include("cudnn.jl")
  @info "Testing Flux/CUDNN RNN"
  include("curnn.jl")
else
  # Without this branch, a machine lacking CUDNN would skip both files with no
  # indication in the output.
  @warn "CUDNN unavailable; skipping Flux/CUDNN BatchNorm and RNN tests"
end