Add tests for CuDNN BatchNorm
This commit is contained in:
parent
91850a8baf
commit
f29377123e
@ -32,4 +32,8 @@ cx = gpu(x)
|
|||||||
|
|
||||||
end
|
end
|
||||||
|
|
||||||
CuArrays.cudnn_available() && include("cudnn.jl")
|
if CuArrays.cudnn_available()
|
||||||
|
info("Testing Flux/CUDNN RNN")
|
||||||
|
include("cudnn.jl")
|
||||||
|
include("curnn.jl")
|
||||||
|
end
|
||||||
|
@ -1,48 +1,8 @@
|
|||||||
using Flux, CuArrays, Base.Test
|
using Flux, Flux.Tracker, CuArrays, Base.Test
|
||||||
|
using Flux: gpu
|
||||||
|
|
||||||
info("Testing Flux/CUDNN")
|
@testset "CUDNN BatchNorm" begin
|
||||||
|
x = gpu(rand(10, 10, 3, 1))
|
||||||
@testset "RNN" begin
|
m = gpu(BatchNorm(3))
|
||||||
@testset for R in [RNN, GRU, LSTM]
|
@test m(x) isa TrackedArray{Float32,4,CuArray{Float32,4}}
|
||||||
rnn = R(10, 5)
|
|
||||||
curnn = mapleaves(gpu, rnn)
|
|
||||||
@testset for batch_size in (1, 5)
|
|
||||||
Flux.reset!(rnn)
|
|
||||||
Flux.reset!(curnn)
|
|
||||||
x = batch_size == 1 ?
|
|
||||||
param(rand(10)) :
|
|
||||||
param(rand(10,batch_size))
|
|
||||||
cux = gpu(x)
|
|
||||||
y = (rnn(x); rnn(x))
|
|
||||||
cuy = (curnn(cux); curnn(cux))
|
|
||||||
|
|
||||||
@test y.data ≈ collect(cuy.data)
|
|
||||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
|
||||||
|
|
||||||
Δ = randn(size(y))
|
|
||||||
|
|
||||||
Flux.back!(y, Δ)
|
|
||||||
Flux.back!(cuy, gpu(Δ))
|
|
||||||
|
|
||||||
@test x.grad ≈ collect(cux.grad)
|
|
||||||
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
|
||||||
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
|
|
||||||
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
|
|
||||||
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
|
|
||||||
if isdefined(rnn.cell, :c)
|
|
||||||
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
|
||||||
end
|
|
||||||
|
|
||||||
Flux.reset!(rnn)
|
|
||||||
Flux.reset!(curnn)
|
|
||||||
ohx = batch_size == 1 ?
|
|
||||||
Flux.onehot(rand(1:10), 1:10) :
|
|
||||||
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
|
||||||
cuohx = gpu(ohx)
|
|
||||||
y = (rnn(ohx); rnn(ohx))
|
|
||||||
cuy = (curnn(cuohx); curnn(cuohx))
|
|
||||||
|
|
||||||
@test y.data ≈ collect(cuy.data)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
46
test/cuda/curnn.jl
Normal file
46
test/cuda/curnn.jl
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
using Flux, CuArrays, Base.Test
|
||||||
|
|
||||||
|
@testset "RNN" begin
|
||||||
|
@testset for R in [RNN, GRU, LSTM]
|
||||||
|
rnn = R(10, 5)
|
||||||
|
curnn = mapleaves(gpu, rnn)
|
||||||
|
@testset for batch_size in (1, 5)
|
||||||
|
Flux.reset!(rnn)
|
||||||
|
Flux.reset!(curnn)
|
||||||
|
x = batch_size == 1 ?
|
||||||
|
param(rand(10)) :
|
||||||
|
param(rand(10,batch_size))
|
||||||
|
cux = gpu(x)
|
||||||
|
y = (rnn(x); rnn(x))
|
||||||
|
cuy = (curnn(cux); curnn(cux))
|
||||||
|
|
||||||
|
@test y.data ≈ collect(cuy.data)
|
||||||
|
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||||
|
|
||||||
|
Δ = randn(size(y))
|
||||||
|
|
||||||
|
Flux.back!(y, Δ)
|
||||||
|
Flux.back!(cuy, gpu(Δ))
|
||||||
|
|
||||||
|
@test x.grad ≈ collect(cux.grad)
|
||||||
|
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
||||||
|
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
|
||||||
|
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
|
||||||
|
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
|
||||||
|
if isdefined(rnn.cell, :c)
|
||||||
|
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
||||||
|
end
|
||||||
|
|
||||||
|
Flux.reset!(rnn)
|
||||||
|
Flux.reset!(curnn)
|
||||||
|
ohx = batch_size == 1 ?
|
||||||
|
Flux.onehot(rand(1:10), 1:10) :
|
||||||
|
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||||
|
cuohx = gpu(ohx)
|
||||||
|
y = (rnn(ohx); rnn(ohx))
|
||||||
|
cuy = (curnn(cuohx); curnn(cuohx))
|
||||||
|
|
||||||
|
@test y.data ≈ collect(cuy.data)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in New Issue
Block a user