Make the test more reliable
This commit is contained in:
parent
4df9e10516
commit
9f12e8ec68
@ -3,7 +3,7 @@ using Flux.Tracker: TrackedArray, data
|
|||||||
|
|
||||||
@testset "CUDNN BatchNorm" begin
|
@testset "CUDNN BatchNorm" begin
|
||||||
@testset "4D Input" begin
|
@testset "4D Input" begin
|
||||||
x = TrackedArray(rand(10, 10, 3, 1))
|
x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
|
||||||
m = BatchNorm(3)
|
m = BatchNorm(3)
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
@ -23,9 +23,9 @@ using Flux.Tracker: TrackedArray, data
|
|||||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||||
@test x.grad ≈ cpu(x.grad)
|
@test x.grad ≈ cpu(x.grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "2D Input" begin
|
@testset "2D Input" begin
|
||||||
x = TrackedArray(rand(3, 1))
|
x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
|
||||||
m = BatchNorm(3)
|
m = BatchNorm(3)
|
||||||
cx = gpu(x)
|
cx = gpu(x)
|
||||||
cm = gpu(m)
|
cm = gpu(m)
|
||||||
|
Loading…
Reference in New Issue
Block a user