train forward pass
This commit is contained in:
parent
9a6fcf057b
commit
14086b8c2d
@ -57,6 +57,7 @@ mutable struct RNNDesc{T}
|
|||||||
params::CuVector{T}
|
params::CuVector{T}
|
||||||
weights::NTuple{2,CuMatrix{T}}
|
weights::NTuple{2,CuMatrix{T}}
|
||||||
bias::CuVector{T}
|
bias::CuVector{T}
|
||||||
|
reserve::CuVector{UInt8}
|
||||||
ptr::Ptr{Void}
|
ptr::Ptr{Void}
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -82,7 +83,8 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
|||||||
|
|
||||||
w = cuzeros(T, rnnParamSize(T, d[], 10))
|
w = cuzeros(T, rnnParamSize(T, d[], 10))
|
||||||
ngates = [1, 1, 4, 3][mode+1]
|
ngates = [1, 1, 4, 3][mode+1]
|
||||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates)..., d[])
|
# TODO: avoid reserve allocation here
|
||||||
|
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates)..., CuVector{UInt8}(1), d[])
|
||||||
finalizer(rd, x ->
|
finalizer(rd, x ->
|
||||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||||
return rd
|
return rd
|
||||||
@ -102,49 +104,64 @@ function rnnTrainingReserveSize(r::RNNDesc, seqlen, xdesc)
|
|||||||
return Int(size[])
|
return Int(size[])
|
||||||
end
|
end
|
||||||
|
|
||||||
function forwardInference(rnn::RNNDesc{T}, x, h, c = nothing) where T
|
function getreserve(r::RNNDesc, seqlen, xdesc)
|
||||||
|
sz = rnnTrainingReserveSize(r, seqlen, xdesc)
|
||||||
|
sz ≤ length(r.reserve) ? r.reserve : (r.reserve = CuVector{UInt8}(sz))
|
||||||
|
end
|
||||||
|
|
||||||
|
function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, workspace, reserve=nothing) where T
|
||||||
|
if reserve == nothing
|
||||||
|
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
||||||
|
(Ptr{Void}, Ptr{Void}, Cint,
|
||||||
|
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||||
|
Ptr{Void}, Csize_t),
|
||||||
|
libcudnn_handle[], rnn, seqlen,
|
||||||
|
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||||
|
workspace, length(workspace))
|
||||||
|
else
|
||||||
|
@check ccall((:cudnnRNNForwardTraining, libcudnn), cudnnStatus_t,
|
||||||
|
(Ptr{Void}, Ptr{Void}, Cint,
|
||||||
|
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
|
||||||
|
Ptr{Void}, Csize_t, Ptr{Void}, Csize_t),
|
||||||
|
libcudnn_handle[], rnn, seqlen,
|
||||||
|
xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
|
||||||
|
workspace, length(workspace), reserve, length(reserve))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = Val{false}) where T
|
||||||
@assert size(x, 1) == rnn.input
|
@assert size(x, 1) == rnn.input
|
||||||
@assert size(h, 1) == rnn.hidden
|
@assert size(h, 1) == rnn.hidden
|
||||||
@assert size(x, 2) == size(h, 2)
|
@assert size(x, 2) == size(h, 2)
|
||||||
seqLength = 1
|
seqLength = 1
|
||||||
xdesc = [TensorDesc(reshape(x, 1, size(x, 1), size(x, 2)))]
|
xdesc = [TensorDesc(T, (1, size(x, 1), size(x, 2)))]
|
||||||
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
||||||
ydesc = [TensorDesc(reshape(y, 1, size(y, 1), size(y, 2)))]
|
ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))]
|
||||||
hout = similar(h)
|
|
||||||
workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this
|
workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this
|
||||||
|
reserve = train == Val{true} ? getreserve(rnn, seqLength, xdesc) : nothing
|
||||||
if c ≠ nothing
|
if c ≠ nothing
|
||||||
@assert size(c, 1) == rnn.hidden
|
@assert size(c, 1) == rnn.hidden
|
||||||
@assert size(c, 2) == size(h, 2)
|
@assert size(c, 2) == size(h, 2)
|
||||||
cptr = c
|
cptr = c
|
||||||
cdesc = TensorDesc(reshape(c, size(c, 1), size(c, 2), 1))
|
cdesc = TensorDesc(T, (size(c, 1), size(c, 2), 1))
|
||||||
cout = similar(c)
|
cout = similar(c)
|
||||||
coutdesc = TensorDesc(reshape(cout, size(cout, 1), size(cout, 2), 1))
|
coutdesc = TensorDesc(T, (size(cout, 1), size(cout, 2), 1))
|
||||||
else
|
else
|
||||||
cptr = cdesc = cout = coutdesc = C_NULL
|
cptr = cdesc = cout = coutdesc = C_NULL
|
||||||
end
|
end
|
||||||
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
|
cudnnRNNForward(T, rnn, seqLength,
|
||||||
(Ptr{Void}, Ptr{Void}, Cint,
|
|
||||||
Ptr{Ptr{Void}}, Ptr{T},
|
|
||||||
Ptr{Void}, Ptr{T},
|
|
||||||
Ptr{Void}, Ptr{T},
|
|
||||||
Ptr{Void}, Ptr{T},
|
|
||||||
Ptr{Ptr{Void}}, Ptr{T},
|
|
||||||
Ptr{Void}, Ptr{T},
|
|
||||||
Ptr{Void}, Ptr{T},
|
|
||||||
Ptr{Void}, Csize_t),
|
|
||||||
libcudnn_handle[], rnn, seqLength,
|
|
||||||
xdesc, x,
|
xdesc, x,
|
||||||
TensorDesc(reshape(h, size(h, 1), size(h, 2), 1)), h,
|
TensorDesc(T, (size(h, 1), size(h, 2), 1)), h,
|
||||||
cdesc, cptr,
|
cdesc, cptr,
|
||||||
TensorDesc(reshape(rnn.params, 1, 1, :)), rnn.params,
|
TensorDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||||
ydesc, y,
|
ydesc, y,
|
||||||
TensorDesc(reshape(hout, size(hout, 1), size(hout, 2), 1)), hout,
|
C_NULL, C_NULL, # hout
|
||||||
coutdesc, cout,
|
coutdesc, cout,
|
||||||
workspace, length(workspace))
|
workspace, reserve)
|
||||||
if c == nothing
|
if c == nothing
|
||||||
return y, hout
|
return y, y
|
||||||
else
|
else
|
||||||
return y, hout, cout
|
return y, y, cout
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -197,17 +214,19 @@ function desc(rnn)
|
|||||||
return d
|
return d
|
||||||
end
|
end
|
||||||
|
|
||||||
|
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
|
||||||
|
|
||||||
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||||
y, h = forwardInference(desc(m), Flux.data(x), Flux.data(h))
|
y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = Val{istrain(m, h, x)})
|
||||||
return h, y
|
return h, y
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||||
y, h = forwardInference(desc(m), Flux.data(x), Flux.data(h))
|
y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = Val{istrain(m, h, x)})
|
||||||
return h, y
|
return h, y
|
||||||
end
|
end
|
||||||
|
|
||||||
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
|
||||||
y, h, c = forwardInference(desc(m), Flux.data(x), Flux.data.(h)...)
|
y, h, c = forward(desc(m), Flux.data(x), Flux.data.(h)..., train = Val{istrain(m, h, x)})
|
||||||
return (h, c), y
|
return (h, c), y
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user