# Flux.jl/test/cuda/cudnn.jl — compare CUDNN-backed recurrent layers against their CPU counterparts.

using Flux, CuArrays, Base.Test
using Flux.CUDA
using Flux.CUDA: RNNDesc
using CUDAnative
# Announce this test file in the test log.
"Testing Flux/CUDNN" |> info
# For each recurrent layer type, build a CPU model and a GPU copy of it
# (via `mapleaves(cu, ...)`). Check that both produce the same forward output
# and the same gradients for the input, the weights, the bias and the hidden
# state(s).
@testset "RNN" begin
    @testset for R in [RNN, GRU, LSTM]
        # Tracked (`param`) input, so its gradient is recorded and can be compared below.
        x = param(rand(10, 5))
        cux = cu(x)
        rnn = R(10, 5)
        curnn = mapleaves(cu, rnn)

        # Forward pass: GPU output must match CPU output (copied back to host with `collect`).
        y = rnn(x)
        cuy = curnn(cux)
        @test y.data ≈ collect(cuy.data)
        # The GPU forward pass is expected to leave an entry for this cell in `Flux.CUDA.descs`.
        @test haskey(Flux.CUDA.descs, curnn.cell)

        # Backward pass: push the same upstream gradient through both models.
        Δ = randn(size(y))
        Flux.back!(y, Δ)
        Flux.back!(cuy, cu(Δ))

        # Gradients must agree for the input and for every parameter of the cell.
        @test x.grad ≈ collect(cux.grad)
        @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
        @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
        @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
        @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
        # Cells that also have a `c` field (presumably LSTM) get that gradient checked too.
        if isdefined(rnn.cell, :c)
            @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
        end
    end
end