Add test for 2D inputs

This commit is contained in:
Avik Pal 2018-11-10 11:52:23 +05:30
parent d6aacf4135
commit 4df9e10516
1 changed files with 38 additions and 14 deletions

View File

@ -2,23 +2,47 @@ using Flux, Flux.Tracker, CuArrays, Test
using Flux.Tracker: TrackedArray, data
@testset "CUDNN BatchNorm" begin
x = TrackedArray(rand(10, 10, 3, 1))
m = BatchNorm(3)
cx = gpu(x)
cm = gpu(m)
@testset "4D Input" begin
x = TrackedArray(rand(10, 10, 3, 1))
m = BatchNorm(3)
cx = gpu(x)
cm = gpu(m)
y = m(x)
cy = cm(cx)
y = m(x)
cy = cm(cx)
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
@test cpu(data(cy)) data(y)
@test cpu(data(cy)) data(y)
g = rand(size(y)...)
Flux.back!(y, g)
Flux.back!(cy, gpu(g))
g = rand(size(y)...)
Flux.back!(y, g)
Flux.back!(cy, gpu(g))
@test m.γ.grad cpu(cm.γ.grad)
@test m.β.grad cpu(cm.β.grad)
@test x.grad cpu(x.grad)
@test m.γ.grad cpu(cm.γ.grad)
@test m.β.grad cpu(cm.β.grad)
@test x.grad cpu(x.grad)
end
@testset "2D Input" begin
x = TrackedArray(rand(3, 1))
m = BatchNorm(3)
cx = gpu(x)
cm = gpu(m)
y = m(x)
cy = cm(cx)
@test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
@test cpu(data(cy)) data(y)
g = rand(size(y)...)
Flux.back!(y, g)
Flux.back!(cy, gpu(g))
@test m.γ.grad cpu(cm.γ.grad)
@test m.β.grad cpu(cm.β.grad)
@test x.grad cpu(x.grad)
end
end