Merge pull request #178 from schmrlng/pull-request/e6f55641
Convert OneHot CuArrays to dense CuArrays before passing to CUDNN methods
This commit is contained in:
commit
e3b4b16e01
@ -325,6 +325,10 @@ function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union
|
|||||||
return (result[2], result[3]), result[1]
|
return (result[2], result[3]), result[1]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
|
||||||
|
|
||||||
function accum_transpose!(dst::CuArray, src::CuArray)
|
function accum_transpose!(dst::CuArray, src::CuArray)
|
||||||
function kernel(dst, src)
|
function kernel(dst, src)
|
||||||
I = @cuindex dst
|
I = @cuindex dst
|
||||||
|
@ -4,10 +4,15 @@ info("Testing Flux/CUDNN")
|
|||||||
|
|
||||||
@testset "RNN" begin
|
@testset "RNN" begin
|
||||||
@testset for R in [RNN, GRU, LSTM]
|
@testset for R in [RNN, GRU, LSTM]
|
||||||
x = param(rand(10,5))
|
|
||||||
cux = cu(x)
|
|
||||||
rnn = R(10, 5)
|
rnn = R(10, 5)
|
||||||
curnn = mapleaves(cu, rnn)
|
curnn = mapleaves(cu, rnn)
|
||||||
|
@testset for batch_size in (1, 5)
|
||||||
|
Flux.reset!(rnn)
|
||||||
|
Flux.reset!(curnn)
|
||||||
|
x = batch_size == 1 ?
|
||||||
|
param(rand(10)) :
|
||||||
|
param(rand(10,batch_size))
|
||||||
|
cux = cu(x)
|
||||||
y = (rnn(x); rnn(x))
|
y = (rnn(x); rnn(x))
|
||||||
cuy = (curnn(cux); curnn(cux))
|
cuy = (curnn(cux); curnn(cux))
|
||||||
|
|
||||||
@ -27,5 +32,17 @@ info("Testing Flux/CUDNN")
|
|||||||
if isdefined(rnn.cell, :c)
|
if isdefined(rnn.cell, :c)
|
||||||
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Flux.reset!(rnn)
|
||||||
|
Flux.reset!(curnn)
|
||||||
|
ohx = batch_size == 1 ?
|
||||||
|
Flux.onehot(rand(1:10), 1:10) :
|
||||||
|
Flux.onehotbatch(rand(1:10, batch_size), 1:10)
|
||||||
|
cuohx = cu(ohx)
|
||||||
|
y = (rnn(ohx); rnn(ohx))
|
||||||
|
cuy = (curnn(cuohx); curnn(cuohx))
|
||||||
|
|
||||||
|
@test y.data ≈ collect(cuy.data)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user