CuRNN updates
This commit is contained in:
parent 8bea60d980 · commit 5fd8ffa47e
@@ -22,10 +22,10 @@ const RNN_ALGO_PERSIST_DYNAMIC = 2
 # LSTM: [weight, bias] × [input, hidden] × [input, forget, newmem, output]

 function params(w::CuVector, input, hidden, n = 1)
-  slice(offset, shape) = reshape(w[offset+(1:prod(shape))], shape)
+  slice(offset, shape) = reshape(w[offset.+(1:prod(shape))], shape)
   wx = slice(0, (input, hidden*n))
   wh = slice(length(wx), (hidden, hidden*n))
-  bias = w[length(wx)+length(wh) + (1:hidden*n)]
+  bias = w[length(wx)+length(wh) .+ (1:hidden*n)]
   (wx, wh), bias
 end

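Both changes in this hunk are the same Julia 0.7/1.0 fix: adding a scalar offset to a range is no longer defined, so the offset has to be broadcast with `.+`. A minimal sketch of the slicing pattern on a plain `Vector` (the values and shapes are illustrative, not from the commit):

    # Slicing one flat parameter vector into shaped blocks, as params()
    # does for the cuDNN weight blob, but on the CPU for illustration.
    w = collect(1.0:12.0)              # stand-in for the flat weight vector

    slice(offset, shape) = reshape(w[offset .+ (1:prod(shape))], shape)

    wx = slice(0, (2, 3))              # elements 1:6 as a 2×3 matrix
    wh = slice(length(wx), (3, 2))     # elements 7:12 as a 3×2 matrix
    # `offset + (1:n)` throws a MethodError on Julia ≥ 1.0; the dot is required.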
@@ -65,8 +65,9 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   w = cuzeros(T, rnnParamSize(T, d[], input))
   # TODO: avoid reserve allocation here
   rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
-  finalizer(rd, x ->
-    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x))
+  finalizer(rd) do x
+    @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Nothing},),x)
+  end
   return rd
 end

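The `finalizer` rewrite tracks the Julia 0.7 API change: the function argument moved to the first position, which is what makes the `do`-block form possible (`finalizer(rd) do x ... end` lowers to `finalizer(x -> ..., rd)`). A minimal sketch, with a hypothetical `Handle` type standing in for the RNN descriptor:

    mutable struct Handle    # finalizers can only be attached to mutable values
        ptr::Ptr{Cvoid}
    end

    h = Handle(C_NULL)

    # The Julia ≤ 0.6 spelling was finalizer(h, x -> release(x)); from 0.7 on
    # the function comes first, so the do-block reads naturally:
    finalizer(h) do x
        x.ptr = C_NULL       # stand-in for the cudnnDestroyRNNDescriptor ccall
    end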
@@ -220,17 +221,17 @@ end

 import ..Flux: Flux, relu
 import ..Tracker: TrackedArray
-using CUDAnative
-using CuArrays: @cuindex, cudims
+using .CuArrays.CUDAnative
+using .CuArrays: @cuindex, cudims

-function copy_transpose!(dst::CuArray, src::CuArray)
+function LinearAlgebra.copy_transpose!(dst::CuArray, src::CuArray)
   function kernel(dst, src)
     I = @cuindex dst
     dst[I...] = src[reverse(I)...]
     return
   end
   blk, thr = cudims(dst)
-  @cuda (blk, thr) kernel(dst, src)
+  @cuda blocks=blk threads=thr kernel(dst, src)
   return dst
 end

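Two separate updates land in this hunk. `@cuda (blk, thr) ...` is the old CUDAnative launch syntax; newer CUDAnative versions take explicit `blocks=`/`threads=` keywords instead of a tuple. And qualifying the definition as `LinearAlgebra.copy_transpose!` makes it extend the stdlib function rather than shadow it with a local one. A sketch of the method-extension point on plain arrays (the loop body is illustrative, not the GPU kernel):

    import LinearAlgebra

    # Without the LinearAlgebra. prefix (or an explicit import of the name),
    # this would define a brand-new local function instead of adding a method.
    function LinearAlgebra.copy_transpose!(dst::Matrix, src::Matrix)
        for j in axes(src, 2), i in axes(src, 1)
            dst[j, i] = src[i, j]
        end
        return dst
    end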
@@ -303,7 +304,7 @@ end
     h_ = hBatch(x, data(h))
     dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
-    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
+    nobacksies(:RNN, (dx, unbroadcast(size(h), dh), transpose(dWi), transpose(dWh), db))
   end
 end

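The `.'` postfix transpose operator was removed in Julia 1.0; `transpose` is the replacement, and it returns a lazy wrapper rather than a copy. A quick illustration:

    A = [1 2; 3 4]

    At = transpose(A)      # lazy Transpose wrapper; A.' is a parse error on ≥ 1.0
    At[1, 2] == A[2, 1]    # true — indices are swapped without copying

The same substitution appears again in the LSTM hunk below.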
@@ -318,6 +319,6 @@ end
     (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
     nobacksies(:RNN,
       (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
-       dWi.', dWh.', db))
+       transpose(dWi), transpose(dWh), db))
   end
 end