fix reserve usage
This commit is contained in:
parent
bc452fcd81
commit
fcbdc49d6b
@ -57,7 +57,6 @@ mutable struct RNNDesc{T}
|
|||||||
params::CuVector{T}
|
params::CuVector{T}
|
||||||
weights::NTuple{2,CuMatrix{T}}
|
weights::NTuple{2,CuMatrix{T}}
|
||||||
bias::CuVector{T}
|
bias::CuVector{T}
|
||||||
reserve::CuVector{UInt8}
|
|
||||||
ptr::Ptr{Void}
|
ptr::Ptr{Void}
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -86,7 +85,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
|||||||
|
|
||||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||||
# TODO: avoid reserve allocation here
|
# TODO: avoid reserve allocation here
|
||||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., CuVector{UInt8}(1), d[])
|
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
|
||||||
finalizer(rd, x ->
|
finalizer(rd, x ->
|
||||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||||
return rd
|
return rd
|
||||||
@ -116,14 +115,9 @@ function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
|||||||
return Int(size[])
|
return Int(size[])
|
||||||
end
|
end
|
||||||
|
|
||||||
function getreserve(r::RNNDesc, seqlen, xdesc)
|
|
||||||
sz = rnnTrainingReserveSize(r, seqlen, xdesc)
|
|
||||||
sz ≤ length(r.reserve) ? r.reserve : (r.reserve = CuVector{UInt8}(sz))
|
|
||||||
end
|
|
||||||
|
|
||||||
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||||
workspace, reserve=nothing; train = (reserve ≠ nothing)) where T
|
workspace, reserve=nothing) where T
|
||||||
if !train
|
if reserve == nothing
|
||||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||||
(Ptr{Void}, Ptr{Void}, Cint,
|
(Ptr{Void}, Ptr{Void}, Cint,
|
||||||
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||||
@ -158,7 +152,7 @@ hBatch(x::AbstractVector, h::CuVector) = h
|
|||||||
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
|
||||||
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)
|
||||||
|
|
||||||
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; train = false) where T
|
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
|
||||||
h = hBatch(x, h_)
|
h = hBatch(x, h_)
|
||||||
c = c_ == nothing ? nothing : hBatch(x, c_)
|
c = c_ == nothing ? nothing : hBatch(x, c_)
|
||||||
@assert size(x, 1) == rnn.input
|
@assert size(x, 1) == rnn.input
|
||||||
@ -170,7 +164,9 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; t
|
|||||||
ho = similar(h)
|
ho = similar(h)
|
||||||
ydesc = xDesc(y)
|
ydesc = xDesc(y)
|
||||||
workspace = getworkspace(rnn, seqLength, xdesc)
|
workspace = getworkspace(rnn, seqLength, xdesc)
|
||||||
reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
|
reserve = train == Val{true} ?
|
||||||
|
CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
|
||||||
|
nothing
|
||||||
co = c == nothing ? c : similar(c)
|
co = c == nothing ? c : similar(c)
|
||||||
cudnnRNNForward(rnn, seqLength,
|
cudnnRNNForward(rnn, seqLength,
|
||||||
xdesc, x,
|
xdesc, x,
|
||||||
@ -180,10 +176,14 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; t
|
|||||||
ydesc, y,
|
ydesc, y,
|
||||||
hDesc(ho)...,
|
hDesc(ho)...,
|
||||||
hDesc(co)...,
|
hDesc(co)...,
|
||||||
workspace, reserve, train = train)
|
workspace, reserve)
|
||||||
return c == nothing ? (y, ho) : (y, ho, co)
|
result = c == nothing ? (y, ho) : (y, ho, co)
|
||||||
|
return train == Val{true} ? (reserve, result) : result
|
||||||
end
|
end
|
||||||
|
|
||||||
|
forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
|
||||||
|
forward(rnn, x, h, c, Val{true})
|
||||||
|
|
||||||
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
|
||||||
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
@check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
|
||||||
@ -196,7 +196,9 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho
|
|||||||
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
|
||||||
end
|
end
|
||||||
|
|
||||||
function backwardData(rnn::RNNDesc{T}, y, dy, dho, dco, h, c) where T
|
function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
|
||||||
|
# Same as above, any more efficient way?
|
||||||
|
dy = dy_ isa Integer ? zeros(y) : dy_
|
||||||
yd = xDesc(y)
|
yd = xDesc(y)
|
||||||
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
|
||||||
dh = similar(h)
|
dh = similar(h)
|
||||||
@ -205,12 +207,12 @@ function backwardData(rnn::RNNDesc{T}, y, dy, dho, dco, h, c) where T
|
|||||||
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
|
||||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||||
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
|
||||||
workspace[], rnn.reserve)
|
workspace[], reserve)
|
||||||
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
return c == nothing ? (dx, dh) : (dx, dh, dc)
|
||||||
end
|
end
|
||||||
|
|
||||||
backwardData(rnn, y, dy, dho, hx) =
|
backwardData(rnn, y, dy, dho, hx, reserve) =
|
||||||
backwardData(rnn, y, dy, dho, nothing, hx, nothing)
|
backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)
|
||||||
|
|
||||||
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
|
||||||
workspace, reserve) where T
|
workspace, reserve) where T
|
||||||
@ -226,12 +228,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, d
|
|||||||
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
workspace, length(workspace), dwd, dw, reserve, length(reserve))
|
||||||
end
|
end
|
||||||
|
|
||||||
function backwardWeights(rnn::RNNDesc{T}, x, h, y) where T
|
function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
|
||||||
dw = zeros(rnn.params)
|
dw = zeros(rnn.params)
|
||||||
cudnnRNNBackwardWeights(rnn, 1,
|
cudnnRNNBackwardWeights(rnn, 1,
|
||||||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||||
workspace[], rnn.reserve)
|
workspace[], reserve)
|
||||||
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -286,12 +288,19 @@ end
|
|||||||
|
|
||||||
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
|
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
|
||||||
|
|
||||||
# TODO: fix reserve space usage
|
mutable struct RNNCall{R}
|
||||||
struct RNNCall{R}
|
|
||||||
rnn::R
|
rnn::R
|
||||||
|
reserve::CuVector{UInt8}
|
||||||
|
RNNCall{R}(rnn::R) where R = new(rnn)
|
||||||
end
|
end
|
||||||
|
|
||||||
(c::RNNCall)(args...) = forward(desc(c.rnn), args..., train = true)
|
RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
|
||||||
|
|
||||||
|
function (c::RNNCall)(args...)
|
||||||
|
rs, result = forwardTrain(desc(c.rnn), args...)
|
||||||
|
c.reserve = rs
|
||||||
|
return result
|
||||||
|
end
|
||||||
|
|
||||||
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||||
|
|
||||||
@ -331,10 +340,10 @@ function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
|
|||||||
y, ho = y_
|
y, ho = y_
|
||||||
dy, dho = Δ
|
dy, dho = Δ
|
||||||
h_ = hBatch(x, data(h))
|
h_ = hBatch(x, data(h))
|
||||||
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_)
|
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
|
||||||
@back(x, dx)
|
@back(x, dx)
|
||||||
@back(h, unbroadcast(h, dh))
|
@back(h, unbroadcast(h, dh))
|
||||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
|
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||||
# We don't have to make this assumption, it's just slightly more complex.
|
# We don't have to make this assumption, it's just slightly more complex.
|
||||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||||
@ -347,11 +356,11 @@ function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
|
|||||||
dy, dho, dco = Δ
|
dy, dho, dco = Δ
|
||||||
h_ = hBatch(x, data(h))
|
h_ = hBatch(x, data(h))
|
||||||
c_ = hBatch(x, data(c))
|
c_ = hBatch(x, data(c))
|
||||||
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_)
|
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
|
||||||
@back(x, dx)
|
@back(x, dx)
|
||||||
@back(h, unbroadcast(h, dh))
|
@back(h, unbroadcast(h, dh))
|
||||||
@back(c, unbroadcast(h, dc))
|
@back(c, unbroadcast(h, dc))
|
||||||
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
|
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
|
||||||
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
|
||||||
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
|
||||||
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
|
||||||
|
@ -130,7 +130,7 @@ end
|
|||||||
|
|
||||||
function (m::LSTMCell)(h_, x)
|
function (m::LSTMCell)(h_, x)
|
||||||
h, c = h_ # TODO: nicer syntax on 0.7
|
h, c = h_ # TODO: nicer syntax on 0.7
|
||||||
b, o = m.b, length(h)
|
b, o = m.b, size(h, 1)
|
||||||
g = m.Wi*x .+ m.Wh*h .+ b
|
g = m.Wi*x .+ m.Wh*h .+ b
|
||||||
input = σ.(gate(g, o, 1))
|
input = σ.(gate(g, o, 1))
|
||||||
forget = σ.(gate(g, o, 2))
|
forget = σ.(gate(g, o, 2))
|
||||||
@ -173,7 +173,7 @@ GRUCell(in, out; init = glorot_uniform) =
|
|||||||
param(zeros(out*3)), param(initn(out)))
|
param(zeros(out*3)), param(initn(out)))
|
||||||
|
|
||||||
function (m::GRUCell)(h, x)
|
function (m::GRUCell)(h, x)
|
||||||
b, o = m.b, length(h)
|
b, o = m.b, size(h, 1)
|
||||||
gx, gh = m.Wi*x, m.Wh*h
|
gx, gh = m.Wi*x, m.Wh*h
|
||||||
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
|
r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
|
||||||
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
|
z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))
|
||||||
|
@ -1,7 +1,4 @@
|
|||||||
using Flux, CuArrays, Base.Test
|
using Flux, CuArrays, Base.Test
|
||||||
using Flux.CUDA
|
|
||||||
using Flux.CUDA: RNNDesc
|
|
||||||
using CUDAnative
|
|
||||||
|
|
||||||
info("Testing Flux/CUDNN")
|
info("Testing Flux/CUDNN")
|
||||||
|
|
||||||
@ -11,8 +8,8 @@ info("Testing Flux/CUDNN")
|
|||||||
cux = cu(x)
|
cux = cu(x)
|
||||||
rnn = R(10, 5)
|
rnn = R(10, 5)
|
||||||
curnn = mapleaves(cu, rnn)
|
curnn = mapleaves(cu, rnn)
|
||||||
y = rnn(x)
|
y = (rnn(x); rnn(x))
|
||||||
cuy = curnn(cux)
|
cuy = (curnn(cux); curnn(cux))
|
||||||
|
|
||||||
@test y.data ≈ collect(cuy.data)
|
@test y.data ≈ collect(cuy.data)
|
||||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||||
|
Loading…
Reference in New Issue
Block a user