From e6f556411aa54d685d4f749f0e6bac4143707946 Mon Sep 17 00:00:00 2001 From: Ed Schmerling Date: Mon, 19 Feb 2018 17:32:15 -0800 Subject: [PATCH 1/2] Convert OneHot CuArrays to dense CuArrays before passing to CUDNN methods --- src/cuda/cudnn.jl | 4 ++++ src/onehot.jl | 1 + test/cuda/cudnn.jl | 49 +++++++++++++++++++++++++++++++--------------- 3 files changed, 38 insertions(+), 16 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index decc91ae..e6ff4068 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -325,6 +325,10 @@ function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union return (result[2], result[3]), result[1] end +(m::CuRNN{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) +(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) + function accum_transpose!(dst::CuArray, src::CuArray) function kernel(dst, src) I = @cuindex dst diff --git a/src/onehot.jl b/src/onehot.jl index 07206dfe..608e3549 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -36,6 +36,7 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) import CuArrays: CuArray, cudaconvert Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) + (::Type{<:CuArray{T}})(x::OneHotMatrix{<:CuArray}) where {T} = broadcast(y -> T(y), x) end function onehot(l, labels) diff --git a/test/cuda/cudnn.jl b/test/cuda/cudnn.jl index 85ab21c8..c29b5ef8 100644 --- a/test/cuda/cudnn.jl +++ b/test/cuda/cudnn.jl @@ -4,28 +4,45 @@ info("Testing Flux/CUDNN") @testset "RNN" begin @testset for R in [RNN, GRU, LSTM] - x = param(rand(10,5)) - cux = cu(x) rnn = R(10, 5) curnn = mapleaves(cu, rnn) - y = (rnn(x); rnn(x)) - cuy = (curnn(cux); curnn(cux)) + @testset for batch_size in (1, 5) + 
Flux.reset!(rnn) + Flux.reset!(curnn) + x = batch_size == 1 ? + param(rand(10)) : + param(rand(10,batch_size)) + cux = cu(x) + y = (rnn(x); rnn(x)) + cuy = (curnn(cux); curnn(cux)) - @test y.data ≈ collect(cuy.data) - @test haskey(Flux.CUDA.descs, curnn.cell) + @test y.data ≈ collect(cuy.data) + @test haskey(Flux.CUDA.descs, curnn.cell) - Δ = randn(size(y)) + Δ = randn(size(y)) - Flux.back!(y, Δ) - Flux.back!(cuy, cu(Δ)) + Flux.back!(y, Δ) + Flux.back!(cuy, cu(Δ)) - @test x.grad ≈ collect(cux.grad) - @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad) - @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad) - @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad) - @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad) - if isdefined(rnn.cell, :c) - @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad) + @test x.grad ≈ collect(cux.grad) + @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad) + @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad) + @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad) + @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad) + if isdefined(rnn.cell, :c) + @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad) + end + + Flux.reset!(rnn) + Flux.reset!(curnn) + ohx = batch_size == 1 ? 
+ Flux.onehot(rand(1:10), 1:10) : + Flux.onehotbatch(rand(1:10, batch_size), 1:10) + cuohx = cu(ohx) + y = (rnn(ohx); rnn(ohx)) + cuy = (curnn(cuohx); curnn(cuohx)) + + @test y.data ≈ collect(cuy.data) end end end From 6bdd283fbd1a7bfc1f9e5c96adb20f27d27ed4b9 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 21 Feb 2018 22:29:24 +0000 Subject: [PATCH 2/2] Remove CuArray{T}(::OneHotMatrix) constructor; no longer necessary --- src/onehot.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/onehot.jl b/src/onehot.jl index 608e3549..07206dfe 100644 --- a/src/onehot.jl +++ b/src/onehot.jl @@ -36,7 +36,6 @@ adapt(T, xs::OneHotMatrix) = OneHotMatrix(xs.height, adapt(T, xs.data)) import CuArrays: CuArray, cudaconvert Base.Broadcast._containertype(::Type{<:OneHotMatrix{<:CuArray}}) = CuArray cudaconvert(x::OneHotMatrix{<:CuArray}) = OneHotMatrix(x.height, cudaconvert(x.data)) - (::Type{<:CuArray{T}})(x::OneHotMatrix{<:CuArray}) where {T} = broadcast(y -> T(y), x) end function onehot(l, labels)