Add tests to make sure CPU and GPU versions have similar outputs
This commit is contained in:
parent
24ba1c4e6c
commit
9a168528de
|
@ -1,8 +1,24 @@
|
|||
using Flux, Flux.Tracker, CuArrays, Base.Test
|
||||
using Flux.Tracker: TrackedArray
|
||||
using Flux: gpu
|
||||
|
||||
@testset "CUDNN BatchNorm" begin
|
||||
x = gpu(rand(10, 10, 3, 1))
|
||||
m = gpu(BatchNorm(3))
|
||||
@test m(x) isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
||||
x = TrackedArray(rand(10, 10, 3, 1))
|
||||
m = BatchNorm(3)
|
||||
cx = gpu(x)
|
||||
cm = gpu(m)
|
||||
|
||||
y = m(x)
|
||||
cy = cm(cx)
|
||||
|
||||
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
||||
|
||||
@test cpu(cy) ≈ y
|
||||
|
||||
Flux.back!(y, ones(y))
|
||||
Flux.back!(cy, ones(cy))
|
||||
|
||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||
@test m.x.grad ≈ cpu(cm.x.grad)
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue