Add test for 2D inputs
This commit is contained in:
parent
d6aacf4135
commit
4df9e10516
@ -2,23 +2,47 @@ using Flux, Flux.Tracker, CuArrays, Test
|
|||||||
using Flux.Tracker: TrackedArray, data
|
using Flux.Tracker: TrackedArray, data
|
||||||
|
|
||||||
@testset "CUDNN BatchNorm" begin
|
@testset "CUDNN BatchNorm" begin
|
||||||
x = TrackedArray(rand(10, 10, 3, 1))
|
@testset "4D Input" begin
|
||||||
m = BatchNorm(3)
|
x = TrackedArray(rand(10, 10, 3, 1))
|
||||||
cx = gpu(x)
|
m = BatchNorm(3)
|
||||||
cm = gpu(m)
|
cx = gpu(x)
|
||||||
|
cm = gpu(m)
|
||||||
|
|
||||||
y = m(x)
|
y = m(x)
|
||||||
cy = cm(cx)
|
cy = cm(cx)
|
||||||
|
|
||||||
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
||||||
|
|
||||||
@test cpu(data(cy)) ≈ data(y)
|
@test cpu(data(cy)) ≈ data(y)
|
||||||
|
|
||||||
g = rand(size(y)...)
|
g = rand(size(y)...)
|
||||||
Flux.back!(y, g)
|
Flux.back!(y, g)
|
||||||
Flux.back!(cy, gpu(g))
|
Flux.back!(cy, gpu(g))
|
||||||
|
|
||||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||||
@test x.grad ≈ cpu(x.grad)
|
@test x.grad ≈ cpu(x.grad)
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "2D Input" begin
|
||||||
|
x = TrackedArray(rand(3, 1))
|
||||||
|
m = BatchNorm(3)
|
||||||
|
cx = gpu(x)
|
||||||
|
cm = gpu(m)
|
||||||
|
|
||||||
|
y = m(x)
|
||||||
|
cy = cm(cx)
|
||||||
|
|
||||||
|
@test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
|
||||||
|
|
||||||
|
@test cpu(data(cy)) ≈ data(y)
|
||||||
|
|
||||||
|
g = rand(size(y)...)
|
||||||
|
Flux.back!(y, g)
|
||||||
|
Flux.back!(cy, gpu(g))
|
||||||
|
|
||||||
|
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||||
|
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||||
|
@test x.grad ≈ cpu(x.grad)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user