backward passes
commit 30b3437c56 (parent f866fbe575)
@@ -97,6 +97,16 @@ function rnnWorkspaceSize(r::RNNDesc, seqlen, xdesc)
  return Int(size[])
end

const workspace = [CuVector{UInt8}(1)]

getworkspace(bytes) =
  length(workspace[]) ≥ bytes ?
    workspace[] :
    (workspace[] = CuVector{UInt8}(bytes))

getworkspace(r::RNNDesc, seqlen, xdesc) =
  getworkspace(rnnWorkspaceSize(r, seqlen, xdesc))
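The two `getworkspace` methods added above keep one byte buffer that only ever grows, so repeated forward calls stop reallocating the cuDNN workspace (the old per-call allocation is removed further down in this diff). A minimal CPU-only sketch of the same grow-only caching pattern, with illustrative names and current Julia's `undef` constructor rather than the Julia 0.6-era API the patch targets:

```julia
# Grow-only scratch buffer: reuse the cached buffer when it is already large
# enough, otherwise swap in a bigger one (mirrors getworkspace above).
const scratch = Ref(Vector{UInt8}(undef, 1))

function getscratch(bytes)
  length(scratch[]) >= bytes || (scratch[] = Vector{UInt8}(undef, bytes))
  return scratch[]
end

getscratch(16)  # grows the buffer to 16 bytes
getscratch(8)   # reuses the existing 16-byte buffer
```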

function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
  size = Csize_t[0]
  @check ccall((:cudnnGetRNNTrainingReserveSize,libcudnn), cudnnStatus_t, (Ptr{Void}, Ptr{Void}, Cint, Ptr{Ptr{Void}}, Ptr{Csize_t}),
@@ -109,12 +119,14 @@ function getreserve(r::RNNDesc, seqlen, xdesc)
  sz ≤ length(r.reserve) ? r.reserve : (r.reserve = CuVector{UInt8}(sz))
end

function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
                         workspace, reserve=nothing; train = (reserve ≠ nothing)) where T
  if !train
    @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
      (Ptr{Void}, Ptr{Void}, Cint,
       Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
       Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
       Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
       Ptr{Void}, Ptr{T},
       Ptr{Void}, Csize_t),
      libcudnn_handle[], rnn, seqlen,
      xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
@@ -130,6 +142,8 @@ function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd,
  end
end
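A detail worth noting in the rewritten `cudnnRNNForward` wrapper: the keyword default `train = (reserve ≠ nothing)` selects the inference ccall when no reserve buffer is passed, and presumably the training ccall (in the branch this hunk truncates) when one is. A standalone sketch of that Julia pattern, with made-up names:

```julia
# A keyword default may refer to earlier positional arguments, so the mode
# follows automatically from whether a reserve buffer was supplied.
runkernel(x, reserve = nothing; train = (reserve ≠ nothing)) =
  train ? :training_path : :inference_path

runkernel(1)            # :inference_path
runkernel(1, UInt8[])   # :training_path
```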

xDesc(x) = [TensorDesc(eltype(x), (1, size(x, 1), size(x, 2)))]

hDesc(h::Void) = C_NULL, C_NULL
function hDesc(h::CuArray)
  TensorDesc(eltype(h), (size(h, 1), size(h, 2), 1)), h
@@ -140,26 +154,73 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; tra
  @assert size(h, 1) == rnn.hidden
  @assert size(x, 2) == size(h, 2)
  seqLength = 1
  xdesc = [TensorDesc(T, (1, size(x, 1), size(x, 2)))]
  xdesc = xDesc(x)
  y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
  ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))]
  workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this
  ydesc = xDesc(y)
  workspace = getworkspace(rnn, seqLength, xdesc)
  reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
  cy = c == nothing ? c : similar(c)
  cudnnRNNForward(T, rnn, seqLength,
  co = c == nothing ? c : similar(c)
  cudnnRNNForward(rnn, seqLength,
    xdesc, x,
    hDesc(h)...,
    hDesc(c)...,
    TensorDesc(T, (1, 1, length(rnn.params))), rnn.params,
    FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
    ydesc, y,
    C_NULL, C_NULL, # hout
    hDesc(cy)...,
    hDesc(co)...,
    workspace, reserve, train = train)
  if c == nothing
    return y, y
  else
    return y, y, cy
  end
  return c == nothing ? (y, y) : (y, y, co)
end
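A rough usage sketch for `forward` (the `rnn` descriptor and array sizes are assumptions, not part of the patch). Plain RNN/GRU descriptors return the output and the new hidden state, matching the `(y, y)` return above, while LSTMs also thread the cell state; passing `train = true` is what fills `rnn.reserve` for the backward calls below:

```julia
# Assumed setup: rnn::RNNDesc{Float32} with rnn.input == 10, rnn.hidden == 20.
x = CuArray(rand(Float32, 10))   # input for one time step
h = CuArray(zeros(Float32, 20))  # previous hidden state
y, hnew = forward(rnn, x, h, train = true)        # RNN/GRU case
# LSTM case also passes and returns the cell state:
# y, hnew, cnew = forward(rnn, x, h, c, train = true)
```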

function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                              wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
  @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
    (Ptr{Void}, Ptr{Void}, Cint,
     Ptr{Ptr{Void}}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T},
     Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void},
     Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
     Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
    libcudnn_handle[], rnn, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
    wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
end

function backwardData(rnn::RNNDesc{T}, y, dy, dho, dco, h, c) where T
  yd = xDesc(y)
  dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
  dh = similar(h)
  dc = c == nothing ? nothing : similar(c)
  cudnnRNNBackwardData(rnn, 1,
    yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
    FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
    hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
    workspace[], rnn.reserve)
  return c == nothing ? (dx, dh) : (dx, dh, dc)
end

backwardData(rnn, y, dy, dho, hx) =
backwardData(rnn, y, dy, dho, nothing, hx, nothing)
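And a sketch of pulling the data gradient back through that same step. This leans on `workspace[]` and `rnn.reserve` still holding the buffers from a `train = true` forward call; `dy` and `dho` stand for assumed gradients with respect to the output and the next step's hidden state:

```julia
dy  = CuArray(ones(Float32, 20))   # gradient of the loss w.r.t. y
dho = CuArray(zeros(Float32, 20))  # hidden-state gradient from the next step
dx, dh = backwardData(rnn, y, dy, dho, h)   # five-argument RNN/GRU form
# LSTM form, with cell-state gradients:
# dx, dh, dc = backwardData(rnn, y, dy, dho, dco, h, c)
```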

function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
                                 workspace, reserve) where T
  @check ccall((:cudnnRNNBackwardWeights,libcudnn), cudnnStatus_t,
    (Ptr{Void}, Ptr{Void}, Cint, # handle, rnnDesc, seqLength
     Ptr{Ptr{Void}}, Ptr{T}, #x
     Ptr{Void}, Ptr{T}, #hx
     Ptr{Ptr{Void}}, Ptr{T}, #y
     Ptr{Void}, Csize_t, #ws
     Ptr{Void}, Ptr{T}, #dw
     Ptr{Void}, Csize_t), #rs
    libcudnn_handle[], rnn, seqlen, xd, x, hd, h, yd, y,
    workspace, length(workspace), dwd, dw, reserve, length(reserve))
end

function backwardWeights(rnn::RNNDesc{T}, x, h, y) where T
  dw = zeros(rnn.params)
  cudnnRNNBackwardWeights(rnn, 1,
    xDesc(x), x, hDesc(h)..., xDesc(y), y,
    FilterDesc(T, (1, 1, length(dw))), dw,
    workspace[], rnn.reserve)
  return params(dw, rnn.input, rnn.hidden)
end
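Finally, the weight gradients for the same step. `backwardWeights` reuses the cached workspace and reserve buffers and hands the packed cuDNN gradient vector to `params`, which is assumed here to split it back into the input-weight, recurrent-weight, and bias blocks used by the Flux layers:

```julia
# x, h, y as in the forward sketch; requires that rnn.reserve is still valid
# from a train = true forward pass.
dws = backwardWeights(rnn, x, h, y)  # packed gradients, split via `params`
```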

# Interface
@@ -194,7 +255,7 @@ function copyparams!(m::CuRNNs, d::RNNDesc)
  return
end

function RNNDesc(m::CuRNNs{T}) where {T}
function RNNDesc(m::CuRNNs{T}) where T
  h, i = length(m.h), size(m.Wi, 2)
  mode = m isa CuRNN ?
    (m.σ == tanh ? RNN_TANH : RNN_RELU) :