move setweights and copy_transpose
This commit is contained in:
parent
5baebf48f4
commit
c5e56b7e04
|
@ -106,7 +106,7 @@ version = "4.0.0"
|
||||||
|
|
||||||
[[CuArrays]]
|
[[CuArrays]]
|
||||||
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
|
||||||
git-tree-sha1 = "155349d2c40568a23cbc4599f0e17e2fdf1bbbcc"
|
git-tree-sha1 = "63b4a10d3a4f22ef215d0970483b18296717d1fb"
|
||||||
repo-rev = "tb/flux"
|
repo-rev = "tb/flux"
|
||||||
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
|
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
|
||||||
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||||
|
|
|
@ -11,7 +11,6 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
|
||||||
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
|
||||||
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
|
||||||
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
|
||||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
|
||||||
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
|
||||||
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
|
||||||
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
|
||||||
|
|
|
@ -1,32 +1,12 @@
|
||||||
import ..Flux: Flux, relu
|
import ..Flux: Flux, relu
|
||||||
using CuArrays.CUDAnative
|
using CuArrays.CUDAnative
|
||||||
using CuArrays: @cuindex, cudims
|
using CuArrays: @cuindex, cudims
|
||||||
using LinearAlgebra
|
|
||||||
|
|
||||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
|
||||||
function kernel(dst, src)
|
|
||||||
I = @cuindex dst
|
|
||||||
dst[I...] = src[reverse(I)...]
|
|
||||||
return
|
|
||||||
end
|
|
||||||
blk, thr = cudims(dst)
|
|
||||||
@cuda blocks=blk threads=thr kernel(dst, src)
|
|
||||||
return dst
|
|
||||||
end
|
|
||||||
|
|
||||||
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
|
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
|
||||||
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
|
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||||
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
|
||||||
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
|
||||||
|
|
||||||
function copyparams!(m::CuRNNs, d::CUDNN.RNNDesc)
|
|
||||||
Wi, Wh = d.weights
|
|
||||||
copy_transpose!(Wi, m.Wi)
|
|
||||||
copy_transpose!(Wh, m.Wh)
|
|
||||||
copy_transpose!(d.bias, m.b)
|
|
||||||
return
|
|
||||||
end
|
|
||||||
|
|
||||||
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
|
function CUDNN.RNNDesc(m::CuRNNs{T}) where T
|
||||||
h, i = length(m.h), size(m.Wi, 2)
|
h, i = length(m.h), size(m.Wi, 2)
|
||||||
mode = m isa CuRNN ?
|
mode = m isa CuRNN ?
|
||||||
|
@ -40,7 +20,7 @@ const descs = WeakKeyDict()
|
||||||
|
|
||||||
function desc(rnn)
|
function desc(rnn)
|
||||||
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
|
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
|
||||||
copyparams!(rnn, d)
|
CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
|
||||||
return d
|
return d
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -22,8 +22,8 @@ end
|
||||||
rand(10, batch_size)
|
rand(10, batch_size)
|
||||||
cux = gpu(x)
|
cux = gpu(x)
|
||||||
|
|
||||||
y, back = forward((r, x) -> (r(x)), rnn, x)
|
y, back = forward((r, x) -> r(x), rnn, x)
|
||||||
cuy, cuback = forward((r, x) -> (r(x)), curnn, cux)
|
cuy, cuback = forward((r, x) -> r(x), curnn, cux)
|
||||||
|
|
||||||
@test y ≈ collect(cuy)
|
@test y ≈ collect(cuy)
|
||||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||||
|
|
Loading…
Reference in New Issue