fix copy_transpose!

This commit is contained in:
Simon 2018-08-15 12:16:12 +02:00 committed by GitHub
parent 4683e925d4
commit a43127f881
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,6 +1,8 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
mutable struct DropoutDesc
ptr::Ptr{Nothing}
states::CuVector{UInt8}
@ -244,14 +246,14 @@ import ..Tracker: TrackedArray
using CUDAnative
using CuArrays: @cuindex, cudims
"""
    LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)

Fill `dst` so that `dst[I...] = src[reverse(I)...]` for every index `I` produced by
`@cuindex dst`, then return `dst`.

The method is attached to `LinearAlgebra.copy_transpose!` (rather than defining a new,
module-local `copy_transpose!`) so that callers going through the `LinearAlgebra` function
reach it for `CuArray` arguments.

No shape validation is performed here.
NOTE(review): presumably `size(src) == reverse(size(dst))` is expected -- confirm at call sites.
"""
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
    # Device-side body: each invocation handles the single index handed out by `@cuindex`.
    function kernel(dst, src)
        I = @cuindex dst
        dst[I...] = src[reverse(I)...]
        return
    end
    blk, thr = cudims(dst)
    # Launch exactly once, using the keyword form of the launch configuration instead of
    # the positional `@cuda (blk, thr) ...` form.
    @cuda blocks=blk threads=thr kernel(dst, src)
    return dst
end