Update tests

This commit is contained in:
Avik Pal 2018-07-04 18:57:43 +05:30
parent e3b10691d2
commit b239fc684e
2 changed files with 8 additions and 7 deletions

View File

@ -33,7 +33,8 @@ cx = gpu(x)
end
if CuArrays.cudnn_available()
info("Testing Flux/CUDNN RNN")
info("Testing Flux/CUDNN BatchNorm")
include("cudnn.jl")
info("Testing Flux/CUDNN RNN")
include("curnn.jl")
end

View File

@ -1,6 +1,5 @@
using Flux, Flux.Tracker, CuArrays, Base.Test
using Flux.Tracker: TrackedArray
using Flux: gpu
using Flux.Tracker: TrackedArray, data
@testset "CUDNN BatchNorm" begin
x = TrackedArray(rand(10, 10, 3, 1))
@ -13,12 +12,13 @@ using Flux: gpu
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
@test cpu(cy) y
@test cpu(data(cy)) data(y)
Flux.back!(y, ones(y))
Flux.back!(cy, ones(cy))
g = ones(size(y)...)
Flux.back!(y, g)
Flux.back!(cy, gpu(g)))
@test m.γ.grad cpu(cm.γ.grad)
@test m.β.grad cpu(cm.β.grad)
@test m.x.grad cpu(cm.x.grad)
@test x.grad cpu(x.grad)
end