# Flux.jl/test/cuda/cudnn.jl

using Flux, CuArrays, Test

@info "Testing Flux/CUDNN"

# Check that recurrent layers moved to the GPU (`mapleaves(gpu, ...)`) agree with
# their CPU originals: forward outputs, gradients w.r.t. the input, the weights,
# the bias and the initial state(s), and forward outputs for one-hot inputs.
# Each layer type is tested with a single sample (vector input) and a batch
# (matrix input).
@testset "RNN" begin
    @testset for R in [RNN, GRU, LSTM]
        rnn = R(10, 5)
        curnn = mapleaves(gpu, rnn)
        @testset for batch_size in (1, 5)
            Flux.reset!(rnn)
            Flux.reset!(curnn)
            # batch_size == 1 exercises the plain-vector input path; otherwise
            # the input is a 10×batch_size matrix.
            x = batch_size == 1 ?
                param(rand(10)) :
                param(rand(10, batch_size))
            cux = gpu(x)
            # Run two steps and keep the second output, so the comparison also
            # covers the hidden state carried over from the first step.
            y = (rnn(x); rnn(x))
            cuy = (curnn(cux); curnn(cux))

            # Approximate comparison: the two results come from different
            # backends (CPU vs. GPU), so exact equality is not expected.
            @test y.data ≈ collect(cuy.data)
            # NOTE(review): presumably asserts the GPU cell was registered in
            # Flux's CUDNN descriptor cache during the forward pass — confirm.
            @test haskey(Flux.CUDA.descs, curnn.cell)

            Δ = randn(size(y))

            Flux.back!(y, Δ)
            Flux.back!(cuy, gpu(Δ))

            @test x.grad ≈ collect(cux.grad)
            @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
            @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
            @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
            @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
            # Only cells that carry a separate cell state have a `c` field
            # (of the types tested here, presumably just LSTM).
            if isdefined(rnn.cell, :c)
                @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
            end

            # Repeat the two-step forward check with one-hot encoded inputs.
            Flux.reset!(rnn)
            Flux.reset!(curnn)
            ohx = batch_size == 1 ?
                Flux.onehot(rand(1:10), 1:10) :
                Flux.onehotbatch(rand(1:10, batch_size), 1:10)
            cuohx = gpu(ohx)
            y = (rnn(ohx); rnn(ohx))
            cuy = (curnn(cuohx); curnn(cuohx))
            @test y.data ≈ collect(cuy.data)
        end
    end
end