Merge pull request #1 from SimonDanisch/patch-1
fix copy_transpose! for cuda
This commit is contained in:
commit
a3ab1cbb98
@ -1,6 +1,8 @@
|
||||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
||||
cudnnDataType, TensorDesc, FilterDesc
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
mutable struct DropoutDesc
|
||||
ptr::Ptr{Nothing}
|
||||
states::CuVector{UInt8}
|
||||
@ -244,14 +246,14 @@ import ..Tracker: TrackedArray
|
||||
using CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
|
||||
function copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
@cuda blocks=blk threads=thr kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user