move setweights and copy_transpose

This commit is contained in:
Mike Innes 2019-09-17 17:22:35 +01:00
parent 5baebf48f4
commit c5e56b7e04
4 changed files with 4 additions and 25 deletions

View File

@ -106,7 +106,7 @@ version = "4.0.0"
[[CuArrays]] [[CuArrays]]
deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"] deps = ["AbstractFFTs", "Adapt", "CEnum", "CUDAapi", "CUDAdrv", "CUDAnative", "GPUArrays", "LinearAlgebra", "MacroTools", "NNlib", "Printf", "Random", "Requires", "SparseArrays", "TimerOutputs"]
git-tree-sha1 = "155349d2c40568a23cbc4599f0e17e2fdf1bbbcc" git-tree-sha1 = "63b4a10d3a4f22ef215d0970483b18296717d1fb"
repo-rev = "tb/flux" repo-rev = "tb/flux"
repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git" repo-url = "https://github.com/JuliaGPU/CuArrays.jl.git"
uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" uuid = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"

View File

@ -11,7 +11,6 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae" CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab" DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d" Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

View File

@ -1,32 +1,12 @@
import ..Flux: Flux, relu import ..Flux: Flux, relu
using CuArrays.CUDAnative using CuArrays.CUDAnative
using CuArrays: @cuindex, cudims using CuArrays: @cuindex, cudims
using LinearAlgebra
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda blocks=blk threads=thr kernel(dst, src)
return dst
end
CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}} CuRNN{T} = Flux.RNNCell{<:Union{typeof(tanh),typeof(relu)},<:CuArray{T,2},<:CuArray{T,1}}
CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}} CuGRU{T} = Flux.GRUCell{<:CuArray{T,2},<:CuArray{T,1}}
CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}} CuLSTM{T} = Flux.LSTMCell{<:CuArray{T,2},<:CuArray{T,1}}
CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}} CuRNNs{T} = Union{CuRNN{T},CuGRU{T},CuLSTM{T}}
function copyparams!(m::CuRNNs, d::CUDNN.RNNDesc)
Wi, Wh = d.weights
copy_transpose!(Wi, m.Wi)
copy_transpose!(Wh, m.Wh)
copy_transpose!(d.bias, m.b)
return
end
function CUDNN.RNNDesc(m::CuRNNs{T}) where T function CUDNN.RNNDesc(m::CuRNNs{T}) where T
h, i = length(m.h), size(m.Wi, 2) h, i = length(m.h), size(m.Wi, 2)
mode = m isa CuRNN ? mode = m isa CuRNN ?
@ -40,7 +20,7 @@ const descs = WeakKeyDict()
function desc(rnn) function desc(rnn)
d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn)) d = haskey(descs, rnn) ? descs[rnn] : (descs[rnn] = CUDNN.RNNDesc(rnn))
copyparams!(rnn, d) CUDNN.setweights!(d, rnn.Wi, rnn.Wh, rnn.b)
return d return d
end end

View File

@ -22,8 +22,8 @@ end
rand(10, batch_size) rand(10, batch_size)
cux = gpu(x) cux = gpu(x)
y, back = forward((r, x) -> (r(x)), rnn, x) y, back = forward((r, x) -> r(x), rnn, x)
cuy, cuback = forward((r, x) -> (r(x)), curnn, cux) cuy, cuback = forward((r, x) -> r(x), curnn, cux)
@test y collect(cuy) @test y collect(cuy)
@test haskey(Flux.CUDA.descs, curnn.cell) @test haskey(Flux.CUDA.descs, curnn.cell)