fix reserve usage

Author: Mike J Innes
Date:   2018-02-08 10:24:59 +00:00
Commit: fcbdc49d6b (parent bc452fcd81)
3 changed files with 39 additions and 33 deletions
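In brief: the reserve buffer moves off RNNDesc. Previously getreserve grew one shared reserve::CuVector{UInt8} per descriptor, but cuDNN expects the reserve space written by a training forward pass to be handed back unchanged to the matching backward calls, so a buffer shared across calls can be clobbered by a second forward before the first backward runs (hence the old "TODO: fix reserve space usage"). Now forwardTrain allocates a fresh reserve per call and returns it, the backward functions take it explicitly, and RNNCall (made mutable) stores it between the forward and backward passes. A minimal sketch of the resulting call order, using the functions as they appear in the diff below (dy and dho stand in for incoming gradients and are illustrative):

    # One training step with the new API (sketch, not code from the diff).
    reserve, (y, ho) = forwardTrain(rnn, x, h)             # fresh reserve per forward
    dx, dh = backwardData(rnn, y, dy, dho, h, reserve)     # same reserve passed back
    (dWi, dWh), db = backwardWeights(rnn, x, h, y, reserve)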

File 1 of 3 (CUDNN RNN wrappers):

@@ -57,7 +57,6 @@ mutable struct RNNDesc{T}
   params::CuVector{T}
   weights::NTuple{2,CuMatrix{T}}
   bias::CuVector{T}
-  reserve::CuVector{UInt8}
   ptr::Ptr{Void}
 end
@@ -86,7 +85,7 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   w = cuzeros(T, rnnParamSize(T, d[], input))
   # TODO: avoid reserve allocation here
-  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., CuVector{UInt8}(1), d[])
+  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., d[])
   finalizer(rd, x ->
     @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
   return rd
@@ -116,14 +115,9 @@ function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
   return Int(size[])
 end

-function getreserve(r::RNNDesc, seqlen, xdesc)
-  sz = rnnTrainingReserveSize(r, seqlen, xdesc)
-  sz ≤ length(r.reserve) ? r.reserve : (r.reserve = CuVector{UInt8}(sz))
-end
-
 function cudnnRNNForward(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
-                         workspace, reserve=nothing; train = (reserve ≠ nothing)) where T
-  if !train
+                         workspace, reserve=nothing) where T
+  if reserve == nothing
     @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
                  (Ptr{Void}, Ptr{Void}, Cint,
                   Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
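With getreserve gone, cudnnRNNForward also drops its train keyword: whether a reserve buffer is supplied is now the single source of truth, and the branch tests reserve == nothing directly (0.6-era style). Under the old code the two could disagree, since forward always passed a real buffer (rnn.reserve) even when train = false. The shape of the dispatch, heavily simplified; inference_path and training_path are hypothetical stand-ins for the two ccall wrappers:

    # Sketch only; the real code calls cudnnRNNForwardInference / cudnnRNNForwardTraining.
    function rnn_forward(args...; reserve = nothing)
      reserve == nothing ? inference_path(args...) :
                           training_path(args..., reserve)
    end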
@@ -158,7 +152,7 @@ hBatch(x::AbstractVector, h::CuVector) = h
 hBatch(x::AbstractMatrix, h::CuVector) = h .* cuones(1, size(x, 2))
 hBatch(x::AbstractMatrix, h::CuMatrix) = h .* cuones(1, size(h,2) == 1 ? size(x,2) : 1)

-function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; train = false) where T
+function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing, train = Val{false}) where T
   h = hBatch(x, h_)
   c = c_ == nothing ? nothing : hBatch(x, c_)
   @assert size(x, 1) == rnn.input
@@ -170,7 +164,9 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; train = false) where T
   ho = similar(h)
   ydesc = xDesc(y)
   workspace = getworkspace(rnn, seqLength, xdesc)
-  reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
+  reserve = train == Val{true} ?
+    CuVector{UInt8}(rnnTrainingReserveSize(rnn, seqLength, xdesc)) :
+    nothing
   co = c == nothing ? c : similar(c)
   cudnnRNNForward(rnn, seqLength,
     xdesc, x,
@@ -180,10 +176,14 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; train = false) where T
     ydesc, y,
     hDesc(ho)...,
     hDesc(co)...,
-    workspace, reserve, train = train)
-  return c == nothing ? (y, ho) : (y, ho, co)
+    workspace, reserve)
+  result = c == nothing ? (y, ho) : (y, ho, co)
+  return train == Val{true} ? (reserve, result) : result
 end

+forwardTrain(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing) where T =
+  forward(rnn, x, h, c, Val{true})
+
 function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
                               wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, rs) where T
   @check ccall((:cudnnRNNBackwardData,libcudnn),cudnnStatus_t,
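forward now threads train through as a positional Val type (0.6 style: the type Val{true} itself, compared with ==, so the branch is decided by a compile-time constant) rather than a boolean keyword, and in training mode it returns the reserve in front of the usual results so the caller owns the buffer; forwardTrain is the convenience wrapper. Expected usage, sketched from the signatures above:

    y, ho = forward(rnn, x, h)                   # inference: results only
    rs, (y, ho) = forwardTrain(rnn, x, h)        # training: reserve first
    rs, (y, ho, co) = forwardTrain(rnn, x, h, c) # with LSTM cell state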
@@ -196,7 +196,9 @@ function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
   wd, w, hd, h, cd, c, dxd, dx, dhd, dh, dcd, dc, ws, length(ws), rs, length(rs))
 end

-function backwardData(rnn::RNNDesc{T}, y, dy, dho, dco, h, c) where T
+function backwardData(rnn::RNNDesc{T}, y, dy_, dho, dco, h, c, reserve) where T
+  # Same as above, any more efficient way?
+  dy = dy_ isa Integer ? zeros(y) : dy_
   yd = xDesc(y)
   dx = y isa AbstractVector ? similar(dy, rnn.input) : similar(dy, rnn.input, size(dy, 2))
   dh = similar(h)
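backwardData now guards against a scalar gradient: the dy_ isa Integer check presumably covers callers (e.g. Flux's Tracker) handing in a literal 0 when no gradient flows for the output, and the adjacent comment notes the allocation could likely be avoided. In 0.6, zeros(y) builds a zero array with y's shape and element type, so the promotion is simply:

    dy_ = 0                                # scalar zero standing in for "no gradient"
    dy = dy_ isa Integer ? zeros(y) : dy_  # zero array shaped like y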
@@ -205,12 +207,12 @@ function backwardData(rnn::RNNDesc{T}, y, dy, dho, dco, h, c) where T
     yd, y, yd, dy, hDesc(dho)..., hDesc(dco)...,
     FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
     hDesc(h)..., hDesc(c)..., xDesc(dx), dx, hDesc(dh)..., hDesc(dc)...,
-    workspace[], rnn.reserve)
+    workspace[], reserve)
   return c == nothing ? (dx, dh) : (dx, dh, dc)
 end

-backwardData(rnn, y, dy, dho, hx) =
-  backwardData(rnn, y, dy, dho, nothing, hx, nothing)
+backwardData(rnn, y, dy, dho, hx, reserve) =
+  backwardData(rnn, y, dy, dho, nothing, hx, nothing, reserve)

 function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
                                  workspace, reserve) where T
@@ -226,12 +228,12 @@ function cudnnRNNBackwardWeights(rnn::RNNDesc{T}, seqlen, xd, x, hd, h, yd, y, dwd, dw,
   workspace, length(workspace), dwd, dw, reserve, length(reserve))
 end

-function backwardWeights(rnn::RNNDesc{T}, x, h, y) where T
+function backwardWeights(rnn::RNNDesc{T}, x, h, y, reserve) where T
   dw = zeros(rnn.params)
   cudnnRNNBackwardWeights(rnn, 1,
     xDesc(x), x, hDesc(h)..., xDesc(y), y,
     FilterDesc(T, (1, 1, length(dw))), dw,
-    workspace[], rnn.reserve)
+    workspace[], reserve)
   return params(dw, rnn.input, rnn.hidden, ngates(rnn))
 end
@@ -286,12 +288,19 @@ end
 import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast

-# TODO: fix reserve space usage
-struct RNNCall{R}
+mutable struct RNNCall{R}
   rnn::R
+  reserve::CuVector{UInt8}
+  RNNCall{R}(rnn::R) where R = new(rnn)
 end

-(c::RNNCall)(args...) = forward(desc(c.rnn), args..., train = true)
+RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
+
+function (c::RNNCall)(args...)
+  rs, result = forwardTrain(desc(c.rnn), args...)
+  c.reserve = rs
+  return result
+end

 istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
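RNNCall becomes mutable with a reserve field that the inner constructor deliberately leaves unassigned: new(rnn) initializes only the first field, and reserve is filled in by the first training forward. Reading an unassigned non-bits field throws UndefRefError, which here doubles as a guard against running the backward pass before any forward. A generic illustration of the pattern (not Flux code):

    mutable struct Lazy
      x::Int
      buf::Vector{UInt8}
      Lazy(x) = new(x)           # buf left undefined
    end

    l = Lazy(1)
    isdefined(l, :buf)           # false
    l.buf = Vector{UInt8}(16)    # 0.6-style uninitialized buffer
    isdefined(l, :buf)           # true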
@@ -331,10 +340,10 @@ function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
   y, ho = y_
   dy, dho = Δ
   h_ = hBatch(x, data(h))
-  dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_)
+  dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve)
   @back(x, dx)
   @back(h, unbroadcast(h, dh))
-  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
+  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
   # We don't have to make this assumption, it's just slightly more complex.
   @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
   istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
@@ -347,11 +356,11 @@ function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c)
   dy, dho, dco = Δ
   h_ = hBatch(x, data(h))
   c_ = hBatch(x, data(c))
-  dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_)
+  dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve)
   @back(x, dx)
   @back(h, unbroadcast(h, dh))
   @back(c, unbroadcast(h, dc))
-  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y)
+  (dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve)
   @assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
   istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
   istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)

File 2 of 3 (recurrent layer cells):

@@ -130,7 +130,7 @@ end
 function (m::LSTMCell)(h_, x)
   h, c = h_ # TODO: nicer syntax on 0.7
-  b, o = m.b, length(h)
+  b, o = m.b, size(h, 1)
   g = m.Wi*x .+ m.Wh*h .+ b
   input = σ.(gate(g, o, 1))
   forget = σ.(gate(g, o, 2))
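Swapping length(h) for size(h, 1) fixes gate slicing for batched states: h can now be an o×batch matrix, where length(h) is o*batch. gate(g, o, n) selects the n-th o-row block of the stacked pre-activations (assuming Flux's index helper of this era, gate(h, n) = (1:h) .+ h*(n-1)), so the oversized offset would index past the end of g:

    h = rand(5, 4)                  # o × batch hidden state
    g = rand(5*4, 4)                # four stacked LSTM gate blocks
    o = size(h, 1)                  # 5: the correct block height
    gate(h, n) = (1:h) .+ h*(n-1)   # index helper, as assumed above
    g[gate(o, 2), :]                # forget block, rows 6:10
    # with o = length(h) = 20, gate(20, 2) = 21:40, past the end of g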
@@ -173,7 +173,7 @@ GRUCell(in, out; init = glorot_uniform) =
             param(zeros(out*3)), param(initn(out)))

 function (m::GRUCell)(h, x)
-  b, o = m.b, length(h)
+  b, o = m.b, size(h, 1)
   gx, gh = m.Wi*x, m.Wh*h
   r = σ.(gate(gx, o, 1) .+ gate(gh, o, 1) .+ gate(b, o, 1))
   z = σ.(gate(gx, o, 2) .+ gate(gh, o, 2) .+ gate(b, o, 2))

File 3 of 3 (CUDNN tests):

@@ -1,7 +1,4 @@
 using Flux, CuArrays, Base.Test
-using Flux.CUDA
-using Flux.CUDA: RNNDesc
-using CUDAnative

 info("Testing Flux/CUDNN")
@@ -11,8 +8,8 @@ info("Testing Flux/CUDNN")
   cux = cu(x)
   rnn = R(10, 5)
   curnn = mapleaves(cu, rnn)
-  y = rnn(x)
-  cuy = curnn(cux)
+  y = (rnn(x); rnn(x))
+  cuy = (curnn(cux); curnn(cux))

   @test y.data ≈ collect(cuy.data)
   @test haskey(Flux.CUDA.descs, curnn.cell)