2019-03-08 14:49:28 +00:00
|
|
|
|
using Flux, CuArrays, Test
|
2019-06-13 13:14:46 +00:00
|
|
|
|
using Zygote
|
2019-06-12 16:04:42 +00:00
|
|
|
|
"""
    trainmode(f, x...)

Evaluate `forward(f, x...)` and return only the first element of what it
returns; the remaining elements are discarded.

NOTE(review): presumably this exists so layers run their training-time code
path during the forward call — confirm against the `forward` implementation.
"""
function trainmode(f, x...)
    result = forward(f, x...)
    return result[1]
end
|
2019-06-13 13:14:46 +00:00
|
|
|
|
|
|
|
|
|
# Checks that a `BatchNorm(3)` layer moved to the GPU with `gpu` agrees with the
# original layer on the same input, once for a 4-D input and once for a 2-D one.
@testset "CUDNN BatchNorm" begin
    @testset "4D Input" begin
        x = Float64.(collect(reshape(1:12, 2, 2, 3, 1)))
        m = BatchNorm(3)
        cx = gpu(x)
        cm = gpu(m)

        y = trainmode(m, x)
        cy = trainmode(cm, cx)

        # Forward outputs must match once the GPU result is brought back.
        @test cpu(data(cy)) ≈ data(y)

        # TODO: the backward-pass comparison is currently disabled. Re-enable
        # it (and compare the resulting gradients) once it is supported:
        #   g = rand(size(y)...)
        #   Flux.back!(y, g)
        #   Flux.back!(cy, gpu(g))

        @test m.γ ≈ cpu(cm.γ)
        @test m.β ≈ cpu(cm.β)
        # Compare against the transferred copy `cx`; the previous `cpu(x)`
        # compared `x` with itself and could never fail.
        @test x ≈ cpu(cx)
    end

    @testset "2D Input" begin
        x = Float64.(collect(reshape(1:12, 3, 4)))
        m = BatchNorm(3)
        cx = gpu(x)
        cm = gpu(m)

        y = trainmode(m, x)
        cy = trainmode(cm, cx)

        # The GPU path must hand back a 2-D Float32 device array.
        @test cy isa CuArray{Float32,2}

        # Forward outputs must match once the GPU result is brought back.
        @test cpu(data(cy)) ≈ data(y)

        # TODO: the backward-pass comparison is currently disabled. Re-enable
        # it (and compare the resulting gradients) once it is supported:
        #   g = rand(size(y)...)
        #   Flux.back!(y, g)
        #   Flux.back!(cy, gpu(g))

        @test m.γ ≈ cpu(cm.γ)
        @test m.β ≈ cpu(cm.β)
        # Compare against the transferred copy `cx`; the previous `cpu(x)`
        # compared `x` with itself and could never fail.
        @test x ≈ cpu(cx)
    end
end
|