From 14086b8c2d61cdd8d06dde60783bc39158a8f26e Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Fri, 2 Feb 2018 17:48:08 +0000 Subject: [PATCH] train forward pass --- src/cuda/cudnn.jl | 81 +++++++++++++++++++++++++++++------------------ 1 file changed, 50 insertions(+), 31 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 5ecb8cf0..8fffa581 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -57,6 +57,7 @@ mutable struct RNNDesc{T} params::CuVector{T} weights::NTuple{2,CuMatrix{T}} bias::CuVector{T} + reserve::CuVector{UInt8} ptr::Ptr{Void} end @@ -82,7 +83,8 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T w = cuzeros(T, rnnParamSize(T, d[], 10)) ngates = [1, 1, 4, 3][mode+1] - rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates)..., d[]) + # TODO: avoid reserve allocation here + rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates)..., CuVector{UInt8}(1), d[]) finalizer(rd, x -> @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x)) return rd @@ -102,49 +104,64 @@ function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc) return Int(size[]) end -function forwardInference(rnn::RNNDesc{T}, x, h, c = nothing) where T +function getreserve(r::RNNDesc, seqlen, xdesc) + sz = rnnTrainingReserveSize(r, seqlen, xdesc) + sz ≤ length(r.reserve) ? r.reserve : (r.reserve = CuVector{UInt8}(sz)) +end + +function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, workspace, reserve=nothing) where T + if reserve == nothing + @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, + (Ptr{Void}, Ptr{Void}, Cint, + Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Csize_t), + libcudnn_handle[], rnn, seqlen, + xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, + workspace, length(workspace)) + else + @check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t, + (Ptr{Void}, Ptr{Void}, Cint, + Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, + Ptr{Void}, Csize_t, Ptr{Void}, Csize_t), + libcudnn_handle[], rnn, seqlen, + xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, + workspace, length(workspace), reserve, length(reserve)) + end +end + +function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = Val{false}) where T @assert size(x, 1) == rnn.input @assert size(h, 1) == rnn.hidden @assert size(x, 2) == size(h, 2) seqLength = 1 - xdesc = [TensorDesc(reshape(x, 1, size(x, 1), size(x, 2)))] + xdesc = [TensorDesc(T, (1, size(x, 1), size(x, 2)))] y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2)) - ydesc = [TensorDesc(reshape(y, 1, size(y, 1), size(y, 2)))] - hout = similar(h) + ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))] workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this + reserve = train == Val{true} ? getreserve(rnn, seqLength, xdesc) : nothing if c ≠ nothing @assert size(c, 1) == rnn.hidden @assert size(c, 2) == size(h, 2) cptr = c - cdesc = TensorDesc(reshape(c, size(c, 1), size(c, 2), 1)) + cdesc = TensorDesc(T, (size(c, 1), size(c, 2), 1)) cout = similar(c) - coutdesc = TensorDesc(reshape(cout, size(cout, 1), size(cout, 2), 1)) + coutdesc = TensorDesc(T, (size(cout, 1), size(cout, 2), 1)) else cptr = cdesc = cout = coutdesc = C_NULL end - @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, - (Ptr{Void}, Ptr{Void}, Cint, - Ptr{Ptr{Void}}, Ptr{T}, - Ptr{Void}, Ptr{T}, - Ptr{Void}, Ptr{T}, - Ptr{Void}, Ptr{T}, - Ptr{Ptr{Void}}, Ptr{T}, - Ptr{Void}, Ptr{T}, - Ptr{Void}, Ptr{T}, - Ptr{Void}, Csize_t), - libcudnn_handle[], rnn, seqLength, - xdesc, x, - TensorDesc(reshape(h, size(h, 1), size(h, 2), 1)), h, - cdesc, cptr, - TensorDesc(reshape(rnn.params, 1, 1, :)), rnn.params, - ydesc, y, - TensorDesc(reshape(hout, size(hout, 1), size(hout, 2), 1)), hout, - coutdesc, cout, - workspace, length(workspace)) + cudnnRNNForward(T, rnn, seqLength, + xdesc, x, + TensorDesc(T, (size(h, 1), size(h, 2), 1)), h, + cdesc, cptr, + TensorDesc(T, (1, 1, length(rnn.params))), rnn.params, + ydesc, y, + C_NULL, C_NULL, # hout + coutdesc, cout, + workspace, reserve) if c == nothing - return y, hout + return y, y else - return y, hout, cout + return y, y, cout end end @@ -197,17 +214,19 @@ function desc(rnn) return d end +istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) + function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} - y, h = forwardInference(desc(m), Flux.data(x), Flux.data(h)) + y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = Val{istrain(m, h, x)}) return h, y end function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} - y, h = forwardInference(desc(m), Flux.data(x), Flux.data(h)) + y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = Val{istrain(m, h, x)}) return h, y end function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} - y, h, c = forwardInference(desc(m), Flux.data(x), Flux.data.(h)...) + y, h, c = forward(desc(m), Flux.data(x), Flux.data.(h)..., train = Val{istrain(m, h, x)}) return (h, c), y end