2019-03-08 14:49:28 +00:00
|
|
|
|
using Flux, CuArrays, Test
|
2019-09-19 17:33:33 +00:00
|
|
|
|
using Flux: pullback
|
2019-06-13 13:14:46 +00:00
|
|
|
|
|
|
|
|
|
# Check that the CUDNN-backed GPU BatchNorm agrees with the CPU implementation:
# the forward output, and the gradients with respect to the layer's γ/β and the
# input, must match. The identical checks run on a 4D and a 2D input; in both
# the dimension of size 3 lines up with the `BatchNorm(3)` channel count.
# The inputs are parameterized instead of duplicating the test body per shape.
@testset "CUDNN BatchNorm" begin
    cases = (
        ("4D Input", Float64.(collect(reshape(1:12, 2, 2, 3, 1)))),
        ("2D Input", Float64.(collect(reshape(1:12, 3, 4)))),
    )

    @testset "$label" for (label, x) in cases
        m = BatchNorm(3)
        cx = gpu(x)
        cm = gpu(m)

        # Forward pass on CPU and GPU, keeping the pullbacks for the gradient check.
        y, back = pullback((m, x) -> m(x), m, x)
        cy, cback = pullback((m, x) -> m(x), cm, cx)

        @test cpu(cy) ≈ y

        # Push the same random cotangent through both pullbacks.
        Δ = randn(size(y))
        dm, dx = back(Δ)
        cdm, cdx = cback(gpu(Δ))

        # The layer gradient comes back wrapped (presumably a `Ref`, as Zygote
        # does for mutable structs), hence the `[]` before reading γ/β.
        @test dm[].γ ≈ cpu(cdm[].γ)
        @test dm[].β ≈ cpu(cdm[].β)
        @test dx ≈ cpu(cdx)
    end
end
|