2018-08-11 09:32:55 +00:00
|
|
|
|
using Flux, Flux.Tracker, CuArrays, Test
|
2018-07-04 13:27:43 +00:00
|
|
|
|
using Flux.Tracker: TrackedArray, data
|
2018-01-30 13:12:33 +00:00
|
|
|
|
|
2018-06-22 12:49:18 +00:00
|
|
|
|
# Checks that a `BatchNorm` layer moved to the GPU with `gpu` agrees with the
# CPU layer, both in its forward output and in the gradients produced by
# `Flux.back!` for the layer parameters and for the input.
@testset "CUDNN BatchNorm" begin
    @testset "4D Input" begin
        # Float64 input of size 2×2×3×1, fed to a layer built with `BatchNorm(3)`.
        x = TrackedArray(Float64.(collect(reshape(1:12, 2, 2, 3, 1))))
        m = BatchNorm(3)
        cx = gpu(x)
        cm = gpu(m)

        y = m(x)
        cy = cm(cx)

        # The GPU result is expected to be a tracked Float32 `CuArray`.
        @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
        @test cpu(data(cy)) ≈ data(y)

        # Back-propagate the same random sensitivity through both paths.
        g = rand(size(y)...)
        Flux.back!(y, g)
        Flux.back!(cy, gpu(g))

        @test m.γ.grad ≈ cpu(cm.γ.grad)
        @test m.β.grad ≈ cpu(cm.β.grad)
        # Compare with the gradient of the GPU input `cx`; comparing `x.grad`
        # with itself would pass trivially and verify nothing.
        @test x.grad ≈ cpu(cx.grad)
    end

    @testset "2D Input" begin
        # Float64 input of size 3×4, fed to a layer built with `BatchNorm(3)`.
        x = TrackedArray(Float64.(collect(reshape(1:12, 3, 4))))
        m = BatchNorm(3)
        cx = gpu(x)
        cm = gpu(m)

        y = m(x)
        cy = cm(cx)

        # The GPU result is expected to be a tracked Float32 `CuArray`.
        @test cy isa TrackedArray{Float32,2,CuArray{Float32,2}}
        @test cpu(data(cy)) ≈ data(y)

        # Back-propagate the same random sensitivity through both paths.
        g = rand(size(y)...)
        Flux.back!(y, g)
        Flux.back!(cy, gpu(g))

        @test m.γ.grad ≈ cpu(cm.γ.grad)
        @test m.β.grad ≈ cpu(cm.β.grad)
        # Compare with the gradient of the GPU input `cx`; comparing `x.grad`
        # with itself would pass trivially and verify nothing.
        @test x.grad ≈ cpu(cx.grad)
    end
end
|