diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index decc91ae..e6ff4068 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -325,6 +325,10 @@ function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union
   return (result[2], result[3]), result[1]
 end
 
+(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
+
 function accum_transpose!(dst::CuArray, src::CuArray)
   function kernel(dst, src)
     I = @cuindex dst
diff --git a/test/cuda/cudnn.jl b/test/cuda/cudnn.jl
index 85ab21c8..c29b5ef8 100644
--- a/test/cuda/cudnn.jl
+++ b/test/cuda/cudnn.jl
@@ -4,28 +4,45 @@ info("Testing Flux/CUDNN")
 
 @testset "RNN" begin
   @testset for R in [RNN, GRU, LSTM]
-    x = param(rand(10,5))
-    cux = cu(x)
     rnn = R(10, 5)
     curnn = mapleaves(cu, rnn)
-    y = (rnn(x); rnn(x))
-    cuy = (curnn(cux); curnn(cux))
+    @testset for batch_size in (1, 5)
+      Flux.reset!(rnn)
+      Flux.reset!(curnn)
+      x = batch_size == 1 ?
+        param(rand(10)) :
+        param(rand(10,batch_size))
+      cux = cu(x)
+      y = (rnn(x); rnn(x))
+      cuy = (curnn(cux); curnn(cux))
 
-    @test y.data ≈ collect(cuy.data)
-    @test haskey(Flux.CUDA.descs, curnn.cell)
+      @test y.data ≈ collect(cuy.data)
+      @test haskey(Flux.CUDA.descs, curnn.cell)
 
-    Δ = randn(size(y))
+      Δ = randn(size(y))
 
-    Flux.back!(y, Δ)
-    Flux.back!(cuy, cu(Δ))
+      Flux.back!(y, Δ)
+      Flux.back!(cuy, cu(Δ))
 
-    @test x.grad ≈ collect(cux.grad)
-    @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
-    @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
-    @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
-    @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
-    if isdefined(rnn.cell, :c)
-      @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
+      @test x.grad ≈ collect(cux.grad)
+      @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
+      @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
+      @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
+      @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
+      if isdefined(rnn.cell, :c)
+        @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
+      end
+
+      Flux.reset!(rnn)
+      Flux.reset!(curnn)
+      ohx = batch_size == 1 ?
+        Flux.onehot(rand(1:10), 1:10) :
+        Flux.onehotbatch(rand(1:10, batch_size), 1:10)
+      cuohx = cu(ohx)
+      y = (rnn(ohx); rnn(ohx))
+      cuy = (curnn(cuohx); curnn(cuohx))
+
+      @test y.data ≈ collect(cuy.data)
     end
   end
 end