Merge pull request #1 from SimonDanisch/patch-1

fix copy_transpose! for cuda
This commit is contained in:
Josh Christie 2018-08-15 12:18:22 +02:00 committed by GitHub
commit a3ab1cbb98
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,8 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle, using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
cudnnDataType, TensorDesc, FilterDesc cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
mutable struct DropoutDesc mutable struct DropoutDesc
ptr::Ptr{Nothing} ptr::Ptr{Nothing}
states::CuVector{UInt8} states::CuVector{UInt8}
@ -244,14 +246,14 @@ import ..Tracker: TrackedArray
using CUDAnative using CUDAnative
using CuArrays: @cuindex, cudims using CuArrays: @cuindex, cudims
function copy_transpose!(dst::CuArray, src::CuArray) function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src) function kernel(dst, src)
I = @cuindex dst I = @cuindex dst
dst[I...] = src[reverse(I)...] dst[I...] = src[reverse(I)...]
return return
end end
blk, thr = cudims(dst) blk, thr = cudims(dst)
@cuda (blk, thr) kernel(dst, src) @cuda blocks=blk threads=thr kernel(dst, src)
return dst return dst
end end