Make the test more reliable

This commit is contained in:
Avik Pal 2018-11-10 14:00:25 +05:30
parent 4df9e10516
commit 9f12e8ec68
1 changed files with 3 additions and 3 deletions

View File

@ -3,7 +3,7 @@ using Flux.Tracker: TrackedArray, data
@testset "CUDNN BatchNorm" begin
@testset "4D Input" begin
x = TrackedArray(rand(10, 10, 3, 1))
x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
m = BatchNorm(3)
cx = gpu(x)
cm = gpu(m)
@ -23,9 +23,9 @@ using Flux.Tracker: TrackedArray, data
@test m.β.grad cpu(cm.β.grad)
@test x.grad cpu(x.grad)
end
@testset "2D Input" begin
x = TrackedArray(rand(3, 1))
x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
m = BatchNorm(3)
cx = gpu(x)
cm = gpu(m)