Flux.jl/test/cuda/cudnn.jl

9 lines
220 B
Julia
Raw Normal View History

2018-06-22 12:49:18 +00:00
using Flux, Flux.Tracker, CuArrays, Base.Test
using Flux: gpu
2018-01-30 13:12:33 +00:00
2018-06-22 12:49:18 +00:00
# Smoke test for BatchNorm on the GPU: it checks only the type of the layer's
# output, not its numerical values.
@testset "CUDNN BatchNorm" begin
    # 10×10×3×1 random input moved with `gpu`. The 3 passed to `BatchNorm`
    # below matches the size of the input's third dimension.
    x = gpu(rand(10, 10, 3, 1))
    m = gpu(BatchNorm(3))

    # The output must be a 4-d Float32 `TrackedArray` whose third type
    # parameter is `CuArray{Float32,4}` (presumably the underlying storage).
    @test m(x) isa TrackedArray{Float32,4,CuArray{Float32,4}}
end