Update tests

This commit is contained in:
Avik Pal 2018-07-04 18:57:43 +05:30
parent e3b10691d2
commit b239fc684e
2 changed files with 8 additions and 7 deletions

View File

@ -33,7 +33,8 @@ cx = gpu(x)
end end
if CuArrays.cudnn_available() if CuArrays.cudnn_available()
info("Testing Flux/CUDNN RNN") info("Testing Flux/CUDNN BatchNorm")
include("cudnn.jl") include("cudnn.jl")
info("Testing Flux/CUDNN RNN")
include("curnn.jl") include("curnn.jl")
end end

View File

@ -1,6 +1,5 @@
using Flux, Flux.Tracker, CuArrays, Base.Test using Flux, Flux.Tracker, CuArrays, Base.Test
using Flux.Tracker: TrackedArray using Flux.Tracker: TrackedArray, data
using Flux: gpu
@testset "CUDNN BatchNorm" begin @testset "CUDNN BatchNorm" begin
x = TrackedArray(rand(10, 10, 3, 1)) x = TrackedArray(rand(10, 10, 3, 1))
@ -13,12 +12,13 @@ using Flux: gpu
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}} @test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
@test cpu(cy) y @test cpu(data(cy)) data(y)
Flux.back!(y, ones(y)) g = ones(size(y)...)
Flux.back!(cy, ones(cy)) Flux.back!(y, g)
Flux.back!(cy, gpu(g)))
@test m.γ.grad cpu(cm.γ.grad) @test m.γ.grad cpu(cm.γ.grad)
@test m.β.grad cpu(cm.β.grad) @test m.β.grad cpu(cm.β.grad)
@test m.x.grad cpu(cm.x.grad) @test x.grad cpu(x.grad)
end end