move setweights and copy_transpose
This commit is contained in:
parent
5baebf48f4
commit
c5e56b7e04
|
@ -106,7 +106,7 @@ version = "4.0.0"
|
|||
|
||||
[[CuArrays]]
|
||||
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
||||
git-tree-sha1 = "155349d2c40568a23cbc4599f0e17e2fdf1bbbcc"
|
||||
git-tree-sha1 = "63b4a10d3a4f22ef215d0970483b18296717d1fb"
|
||||
repo-rev = "tb/flux"
|
||||
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
|
||||
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||
|
|
|
@ -11,7 +11,6 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
|||
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
||||
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||
|
|
|
@ -1,32 +1,12 @@
|
|||
import ..Flux: Flux, relu
|
||||
using CuArrays.CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
using LinearAlgebra
|
||||
|
||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda blocks=blk threads=thr kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||
|
||||
function copyparams!(m::CuRNNs, d::CUDNN.RNNDesc)
|
||||
Wi, Wh = d.weights
|
||||
copy_transpose!(Wi, m.Wi)
|
||||
copy_transpose!(Wh, m.Wh)
|
||||
copy_transpose!(d.bias, m.b)
|
||||
return
|
||||
end
|
||||
|
||||
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
|
||||
h, i = length(m.h), size(m.Wi, 2)
|
||||
mode = m isa CuRNN ?
|
||||
|
@ -40,7 +20,7 @@ const descs = WeakKeyDict()
|
|||
|
||||
function desc(rnn)
|
||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
|
||||
copyparams!(rnn, d)
|
||||
CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
|
||||
return d
|
||||
end
|
||||
|
||||
|
|
|
@ -22,8 +22,8 @@ end
|
|||
rand(10, batch_size)
|
||||
cux = gpu(x)
|
||||
|
||||
y, back = forward((r, x) -> (r(x)), rnn, x)
|
||||
cuy, cuback = forward((r, x) -> (r(x)), curnn, cux)
|
||||
y, back = forward((r, x) -> r(x), rnn, x)
|
||||
cuy, cuback = forward((r, x) -> r(x), curnn, cux)
|
||||
|
||||
@test y ≈ collect(cuy)
|
||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||
|
|
Loading…
Reference in New Issue