CuRNN updates

This commit is contained in:
Avik Pal 2018-09-11 15:44:07 +05:30
parent 8bea60d980
commit 5fd8ffa47e

View File

@ -22,10 +22,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
function params(w::CuVector, input, hidden, n = 1)
slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
wx = slice(0, (input, hidden*n))
wh = slice(length(wx), (hidden, hidden*n))
bias = w[length(wx)+length(wh) + (1:hidden*n)]
bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
(wx, wh), bias
end
@ -65,8 +65,9 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
w = cuzeros(T, rnnParamSize(T, d[], input))
# TODO: avoid reserve allocation here
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
finalizer(rd, x ->
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x))
finalizer(rd) do x
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
end
return rd
end
@ -220,17 +221,17 @@ end
import ..Flux: Flux, relu
import ..Tracker: TrackedArray
using CUDAnative
using CuArrays: @cuindex, cudims
using .CuArrays.CUDAnative
using .CuArrays: @cuindex, cudims
function copy_transpose!(dst::CuArray, src::CuArray)
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
function kernel(dst, src)
I = @cuindex dst
dst[I...] = src[reverse(I)...]
return
end
blk, thr = cudims(dst)
@cuda (blk, thr) kernel(dst, src)
@cuda blocks=blk threads=thr kernel(dst, src)
return dst
end
@ -303,7 +304,7 @@ end
h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
end
end
@ -318,6 +319,6 @@ end
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
nobacksies(:RNN,
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
dWi.', dWh.', db))
transpose(dWi), transpose(dWh), db))
end
end