Updated tests in cudnn.jl

This commit is contained in:
thebhatman 2019-06-13 18:44:46 +05:30
parent 25f74d1b4a
commit 80c680c598

View File

@ -1,48 +1,47 @@
using Flux, CuArrays, Test using Flux, CuArrays, Test
using Zygote
trainmode(f, x...) = forward(f, x...)[1] trainmode(f, x...) = forward(f, x...)[1]
#
# @testset "CUDNN BatchNorm" begin @testset "CUDNN BatchNorm" begin
# @testset "4D Input" begin @testset "4D Input" begin
# x = Float64.(collect(reshape(1:12, 2, 2, 3, 1))) x = Float64.(collect(reshape(1:12, 2, 2, 3, 1)))
# m = BatchNorm(3) m = BatchNorm(3)
# cx = gpu(x) cx = gpu(x)
# cm = gpu(m) cm = gpu(m)
#
# y = trainmode(m, x) y = trainmode(m, x)
# cy = trainmode(cm, cx) cy = trainmode(cm, cx)
#
# # @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}} @test cpu(data(cy)) data(y)
#
# @test cpu(data(cy)) ≈ data(y) g = rand(size(y)...)
# # Flux.back!(y, g)
# g = rand(size(y)...) # Flux.back!(cy, gpu(g))
# Flux.back!(y, g)
# Flux.back!(cy, gpu(g)) @test m.γ cpu(cm.γ)
# @test m.β cpu(cm.β)
# @test m.γ.grad ≈ cpu(cm.γ.grad) @test x cpu(x)
# @test m.β.grad ≈ cpu(cm.β.grad) end
# @test x.grad ≈ cpu(x.grad)
# end @testset "2D Input" begin
# x = Float64.(collect(reshape(1:12, 3, 4)))
# @testset "2D Input" begin m = BatchNorm(3)
# x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4)))) cx = gpu(x)
# m = BatchNorm(3) cm = gpu(m)
# cx = gpu(x)
# cm = gpu(m) y = trainmode(m, x)
# cy = trainmode(cm, cx)
# y = m(x)
# cy = cm(cx) @test cy isa CuArray{Float32,2}
#
# @test cy isa TrackedArray{Float32,2,CuArray{Float32,2}} @test cpu(data(cy)) data(y)
#
# @test cpu(data(cy)) ≈ data(y) g = rand(size(y)...)
# #Flux.back!(y, g)
# g = rand(size(y)...) #Flux.back!(cy, gpu(g))
# Flux.back!(y, g)
# Flux.back!(cy, gpu(g)) @test m.γ cpu(cm.γ)
# @test m.β cpu(cm.β)
# @test m.γ.grad ≈ cpu(cm.γ.grad) @test x cpu(x)
# @test m.β.grad ≈ cpu(cm.β.grad) end
# @test x.grad ≈ cpu(x.grad) end
# end
# end