Update tests
This commit is contained in:
parent
e3b10691d2
commit
b239fc684e
@ -33,7 +33,8 @@ cx = gpu(x)
|
|||||||
end
|
end
|
||||||
|
|
||||||
if CuArrays.cudnn_available()
|
if CuArrays.cudnn_available()
|
||||||
info("Testing Flux/CUDNN RNN")
|
info("Testing Flux/CUDNN BatchNorm")
|
||||||
include("cudnn.jl")
|
include("cudnn.jl")
|
||||||
|
info("Testing Flux/CUDNN RNN")
|
||||||
include("curnn.jl")
|
include("curnn.jl")
|
||||||
end
|
end
|
||||||
|
@ -1,6 +1,5 @@
|
|||||||
using Flux, Flux.Tracker, CuArrays, Base.Test
|
using Flux, Flux.Tracker, CuArrays, Base.Test
|
||||||
using Flux.Tracker: TrackedArray
|
using Flux.Tracker: TrackedArray, data
|
||||||
using Flux: gpu
|
|
||||||
|
|
||||||
@testset "CUDNN BatchNorm" begin
|
@testset "CUDNN BatchNorm" begin
|
||||||
x = TrackedArray(rand(10, 10, 3, 1))
|
x = TrackedArray(rand(10, 10, 3, 1))
|
||||||
@ -13,12 +12,13 @@ using Flux: gpu
|
|||||||
|
|
||||||
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
@test cy isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
||||||
|
|
||||||
@test cpu(cy) ≈ y
|
@test cpu(data(cy)) ≈ data(y)
|
||||||
|
|
||||||
Flux.back!(y, ones(y))
|
g = ones(size(y)...)
|
||||||
Flux.back!(cy, ones(cy))
|
Flux.back!(y, g)
|
||||||
|
Flux.back!(cy, gpu(g)))
|
||||||
|
|
||||||
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
@test m.γ.grad ≈ cpu(cm.γ.grad)
|
||||||
@test m.β.grad ≈ cpu(cm.β.grad)
|
@test m.β.grad ≈ cpu(cm.β.grad)
|
||||||
@test m.x.grad ≈ cpu(cm.x.grad)
|
@test x.grad ≈ cpu(x.grad)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user