# Flux.jl/test/cuda/cudnn.jl

using Flux, Flux.Tracker, CuArrays, Base.Test
2018-07-04 13:27:43 +00:00
using Flux.Tracker: TrackedArray, data
2018-01-30 13:12:33 +00:00
2018-06-22 12:49:18 +00:00
# Checks that BatchNorm run on the GPU agrees with the CPU implementation,
# for both the forward output and the gradients produced by `Flux.back!`.
@testset "CUDNN BatchNorm" begin
    # CPU reference: a tracked 10x10x3x1 input and a 3-channel BatchNorm layer.
    x = TrackedArray(rand(10, 10, 3, 1))
    m = BatchNorm(3)

    # GPU copies of the same input and layer.
    cx = gpu(x)
    cm = gpu(m)

    # Forward pass on each device.
    y = m(x)
    cy = cm(cx)

    # The GPU result must stay tracked and be backed by a Float32 CuArray.
    @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}

    # Forward outputs agree (approximately: the GPU path is Float32).
    @test cpu(data(cy)) ≈ data(y)

    # Back-propagate the same all-ones sensitivity through both graphs.
    g = ones(size(y)...)
    Flux.back!(y, g)
    Flux.back!(cy, gpu(g))

    # Parameter gradients agree between CPU and GPU layers.
    @test m.γ.grad ≈ cpu(cm.γ.grad)
    @test m.β.grad ≈ cpu(cm.β.grad)

    # Input gradients agree. The GPU side must be read from `cx`; comparing
    # `x.grad` with `cpu(x.grad)` would be trivially true.
    # NOTE(review): assumes `gpu(x)` returns a tracked array with its own
    # `grad` field -- confirm against the Tracker/Adapt integration.
    @test x.grad ≈ cpu(cx.grad)
end