Merge pull request #1 from SimonDanisch/patch-1
fix copy_transpose! for cuda
This commit is contained in:
commit
a3ab1cbb98
@ -1,6 +1,8 @@
|
|||||||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
|
||||||
cudnnDataType, TensorDesc, FilterDesc
|
cudnnDataType, TensorDesc, FilterDesc
|
||||||
|
|
||||||
|
using LinearAlgebra
|
||||||
|
|
||||||
mutable struct DropoutDesc
|
mutable struct DropoutDesc
|
||||||
ptr::Ptr{Nothing}
|
ptr::Ptr{Nothing}
|
||||||
states::CuVector{UInt8}
|
states::CuVector{UInt8}
|
||||||
@ -244,14 +246,14 @@ import ..Tracker: TrackedArray
|
|||||||
using CUDAnative
|
using CUDAnative
|
||||||
using CuArrays: @cuindex, cudims
|
using CuArrays: @cuindex, cudims
|
||||||
|
|
||||||
function copy_transpose!(dst::CuArray, src::CuArray)
|
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||||
function kernel(dst, src)
|
function kernel(dst, src)
|
||||||
I = @cuindex dst
|
I = @cuindex dst
|
||||||
dst[I...] = src[reverse(I)...]
|
dst[I...] = src[reverse(I)...]
|
||||||
return
|
return
|
||||||
end
|
end
|
||||||
blk, thr = cudims(dst)
|
blk, thr = cudims(dst)
|
||||||
@cuda (blk, thr) kernel(dst, src)
|
@cuda blocks=blk threads=thr kernel(dst, src)
|
||||||
return dst
|
return dst
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user