CuRNN updates
This commit is contained in:
parent
8bea60d980
commit
5fd8ffa47e
@ -22,10 +22,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
|
||||
# LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]
|
||||
|
||||
function params(w::CuVector, input, hidden, n = 1)
|
||||
slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
|
||||
slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
|
||||
wx = slice(0, (input, hidden*n))
|
||||
wh = slice(length(wx), (hidden, hidden*n))
|
||||
bias = w[length(wx)+length(wh) + (1:hidden*n)]
|
||||
bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
|
||||
(wx, wh), bias
|
||||
end
|
||||
|
||||
@ -65,8 +65,9 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||
# TODO: avoid reserve allocation here
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||
finalizer(rd, x ->
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x))
|
||||
finalizer(rd) do x
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
|
||||
end
|
||||
return rd
|
||||
end
|
||||
|
||||
@ -220,17 +221,17 @@ end
|
||||
|
||||
import ..Flux: Flux, relu
|
||||
import ..Tracker: TrackedArray
|
||||
using CUDAnative
|
||||
using CuArrays: @cuindex, cudims
|
||||
using .CuArrays.CUDAnative
|
||||
using .CuArrays: @cuindex, cudims
|
||||
|
||||
function copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
|
||||
function kernel(dst, src)
|
||||
I = @cuindex dst
|
||||
dst[I...] = src[reverse(I)...]
|
||||
return
|
||||
end
|
||||
blk, thr = cudims(dst)
|
||||
@cuda (blk, thr) kernel(dst, src)
|
||||
@cuda blocks=blk threads=thr kernel(dst, src)
|
||||
return dst
|
||||
end
|
||||
|
||||
@ -303,7 +304,7 @@ end
|
||||
h_ = hBatch(x, data(h))
|
||||
dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
|
||||
nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
||||
@ -318,6 +319,6 @@ end
|
||||
(dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
|
||||
nobacksies(:RNN,
|
||||
(dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
|
||||
dWi.', dWh.', db))
|
||||
transpose(dWi), transpose(dWh), db))
|
||||
end
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user