avoid val

This commit is contained in:
Mike J Innes 2018-02-06 12:41:50 +00:00
parent 14086b8c2d
commit 07e1b1e0a9

View File

@ -109,8 +109,9 @@ function getreserve(r::RNNDesc, seqlen, xdesc)
sz ≤ length(r.reserve) ? r.reserve : (r.reserve = CuVector{UInt8}(sz)) sz ≤ length(r.reserve) ? r.reserve : (r.reserve = CuVector{UInt8}(sz))
end end
function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co, workspace, reserve=nothing) where T function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd, y, hod, ho, cod, co,
if reserve == nothing workspace, reserve=nothing; train = (reserve ≠ nothing)) where T
if !train
@check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t, @check ccall((:cudnnRNNForwardInference, libcudnn), cudnnStatus_t,
(Ptr{Void}, Ptr{Void}, Cint, (Ptr{Void}, Ptr{Void}, Cint,
Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Ptr{Void}}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T},
@ -129,7 +130,7 @@ function cudnnRNNForward(::Type{T}, rnn, seqlen, xd, x, hd, h, cd, c, wd, w, yd,
end end
end end
function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = Val{false}) where T function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; train = false) where T
@assert size(x, 1) == rnn.input @assert size(x, 1) == rnn.input
@assert size(h, 1) == rnn.hidden @assert size(h, 1) == rnn.hidden
@assert size(x, 2) == size(h, 2) @assert size(x, 2) == size(h, 2)
@ -138,7 +139,7 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; tra
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2)) y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))] ydesc = [TensorDesc(T, (1, size(y, 1), size(y, 2)))]
workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this workspace = CuVector{UInt8}(rnnWorkspaceSize(rnn, seqLength, xdesc)) # TODO: reuse this
reserve = train == Val{true} ? getreserve(rnn, seqLength, xdesc) : nothing reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
if c ≠ nothing if c ≠ nothing
@assert size(c, 1) == rnn.hidden @assert size(c, 1) == rnn.hidden
@assert size(c, 2) == size(h, 2) @assert size(c, 2) == size(h, 2)
@ -157,7 +158,7 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h::CuArray{T}, c = nothing; tra
ydesc, y, ydesc, y,
C_NULL, C_NULL, # hout C_NULL, C_NULL, # hout
coutdesc, cout, coutdesc, cout,
workspace, reserve) workspace, reserve, train = train)
if c == nothing if c == nothing
return y, y return y, y
else else
@ -217,16 +218,16 @@ end
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = Val{istrain(m, h, x)}) y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = istrain(m, h, x))
return h, y return h, y
end end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = Val{istrain(m, h, x)}) y, h = forward(desc(m), Flux.data(x), Flux.data(h), train = istrain(m, h, x))
return h, y return h, y
end end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
y, h, c = forward(desc(m), Flux.data(x), Flux.data.(h)..., train = Val{istrain(m, h, x)}) y, h, c = forward(desc(m), Flux.data(x), Flux.data.(h)..., train = istrain(m, h, x))
return (h, c), y return (h, c), y
end end