rewrite tests

This commit is contained in:
parent d592f4e327
commit bc452fcd81
@@ -70,6 +70,9 @@ function rnnParamSize(T, r, input)
   return Int(size[])÷sizeof(T)
 end
 
+ngates(mode) = [1, 1, 4, 3][mode+1]
+ngates(r::RNNDesc) = ngates(r.mode)
+
 function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   d = [C_NULL]
   @check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
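
Note: the new ngates helpers encode how many gate blocks each cuDNN cell type carries, indexed by cuDNN's mode enum (RNN_RELU = 0, RNN_TANH = 1, LSTM = 2, GRU = 3, matching the CUDA.RNN_RELU/GRU/LSTM constants used in the tests below). A quick worked check of the mode+1 lookup:

    ngates(0) == 1  # RNN_RELU: a single transformation, no gating
    ngates(1) == 1  # RNN_TANH
    ngates(2) == 4  # LSTM: input, forget, cell, output
    ngates(3) == 3  # GRU: reset, update, candidate
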
@@ -81,10 +84,9 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
   @check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
     libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
 
-  w = cuzeros(T, rnnParamSize(T, d[], 10))
-  ngates = [1, 1, 4, 3][mode+1]
+  w = cuzeros(T, rnnParamSize(T, d[], input))
   # TODO: avoid reserve allocation here
-  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates)..., CuVector{UInt8}(1), d[])
+  rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., CuVector{UInt8}(1), d[])
   finalizer(rd, x ->
     @check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
   return rd
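
Note: replacing the hard-coded 10 with input matters because rnnParamSize asks cuDNN for the size of the packed parameter blob, which depends on the input width; the literal only happened to work for input = 10. As a rough cross-check, a sketch assuming cuDNN's usual single-layer, unidirectional layout of per-gate input and hidden matrices plus two bias vectors per gate (`expected` is illustrative, not a function in this codebase):

    expected(input, hidden, gates) = gates * (input*hidden + hidden*hidden + 2hidden)
    expected(10, 5, 4) == 340  # e.g. an LSTM with input = 10, hidden = 5
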
@@ -165,6 +167,7 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; t
   seqLength = 1
   xdesc = xDesc(x)
   y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
+  ho = similar(h)
   ydesc = xDesc(y)
   workspace = getworkspace(rnn, seqLength, xdesc)
   reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
@@ -175,10 +178,10 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; t
     hDesc(c)...,
     FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
     ydesc, y,
-    C_NULL, C_NULL, # hout
+    hDesc(ho)...,
     hDesc(co)...,
     workspace, reserve, train = train)
-  return c == nothing ? (y, y) : (y, y, co)
+  return c == nothing ? (y, ho) : (y, ho, co)
 end
 
 function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
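
Note: passing hDesc(ho)... instead of C_NULL and allocating ho = similar(h) above means callers now receive the hidden state cuDNN actually wrote, rather than assuming it aliases the output y. The call shape after this change (a sketch; names as in this file):

    y, h′ = forward(rnn, x, h)         # RNN_RELU / GRU: output and new hidden state
    y, h′, c′ = forward(rnn, x, h, c)  # LSTM: also returns the new cell state
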
@@ -229,7 +232,7 @@ function backwardWeights(rnn::RNNDesc{T}, x, h, y) where T
     xDesc(x), x, hDesc(h)..., xDesc(y), y,
     FilterDesc(T, (1, 1, length(dw))), dw,
     workspace[], rnn.reserve)
-  return params(dw, rnn.input, rnn.hidden)
+  return params(dw, rnn.input, rnn.hidden, ngates(rnn))
 end
 
 # Interface
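
Note: the gradient blob dw has to be partitioned with the same gate count as the forward weights, hence ngates(rnn) is threaded through here as well. Each gate occupies one hidden-sized slice of the packed arrays, the same slicing the old tests' gate helper used (a worked example; gate_slice is illustrative only):

    gate_slice(hidden, n) = (1:hidden) .+ hidden*(n-1)
    gate_slice(5, 3) == 11:15  # candidate-gate rows of a GRU with hidden = 5
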
@@ -283,6 +286,7 @@ end
 
 import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
 
+# TODO: fix reserve space usage
 struct RNNCall{R}
   rnn::R
 end
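
Note: the TODO concerns the descriptor-level reserve buffer: cuDNN requires the reserve space written by a training-mode forward pass to be handed, unmodified, to the matching backward calls, so one shared buffer per descriptor is unsafe across overlapping calls. One possible shape for a fix (hypothetical; not part of this commit):

    struct RNNCall{R}
      rnn::R
      reserve::CuVector{UInt8}  # written by the forward pass, consumed by backward
    end
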
@@ -5,65 +5,30 @@ using CUDAnative
 
 info("Testing Flux/CUDNN")
 
-function randinit(r::RNNDesc{T}) where T
-  for w in (r.weights..., r.bias)
-    copy!(w, randn(T, size(w)))
+@testset "RNN" begin
+  @testset for R in [RNN, GRU, LSTM]
+    x = param(rand(10,5))
+    cux = cu(x)
+    rnn = R(10, 5)
+    curnn = mapleaves(cu, rnn)
+    y = rnn(x)
+    cuy = curnn(cux)
+
+    @test y.data ≈ collect(cuy.data)
+    @test haskey(Flux.CUDA.descs, curnn.cell)
+
+    Δ = randn(size(y))
+
+    Flux.back!(y, Δ)
+    Flux.back!(cuy, cu(Δ))
+
+    @test x.grad ≈ collect(cux.grad)
+    @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
+    @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
+    @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
+    @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
+    if isdefined(rnn.cell, :c)
+      @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
+    end
   end
 end
-
-const cutanh = CUDAnative.tanh
-
-gate(rnn, x, n) = x[(1:rnn.hidden) + rnn.hidden*(n-1)]
-
-function test_forward(rnn::RNNDesc, x, h, c = nothing)
-  if rnn.mode == CUDA.RNN_RELU
-    Wx, Wh = rnn.weights
-    b = rnn.bias
-    h′ = relu.(Wx'*x .+ Wh'*h .+ b)
-    return h′, h′
-  elseif rnn.mode == CUDA.GRU
-    Wx, Wh = rnn.weights
-    b = rnn.bias
-    gx, gh = Wx'*x, Wh'*h
-    r = σ.(gate(rnn, gx, 1) .+ gate(rnn, gh, 1) .+ gate(rnn, b, 1))
-    z = σ.(gate(rnn, gx, 2) .+ gate(rnn, gh, 2) .+ gate(rnn, b, 2))
-    h̃ = cutanh.(gate(rnn, gx, 3) .+ r .* gate(rnn, gh, 3) .+ gate(rnn, b, 3))
-    h′ = (1.-z).*h̃ .+ z.*h
-    return h′, h′
-  elseif rnn.mode == CUDA.LSTM
-    Wx, Wh = rnn.weights
-    b = rnn.bias
-    g = Wx'*x .+ Wh'*h .+ b
-    input = σ.(gate(rnn, g, 1))
-    forget = σ.(gate(rnn, g, 2))
-    cell = cutanh.(gate(rnn, g, 3))
-    output = σ.(gate(rnn, g, 4))
-    c = forget .* c .+ input .* cell
-    h = output .* cutanh.(c)
-    return (h, h, c)
-  end
-end
-
-@testset "CUDNN" begin
-
-  rnn = RNNDesc{Float32}(CUDA.RNN_RELU, 10, 5)
-  randinit(rnn)
-  x, h = cu(rand(10)), cu(rand(5))
-  @test collect(test_forward(rnn, x, h)[1]) ≈
-    collect(CUDA.forwardInference(rnn, x, h)[1])
-
-  rnn = RNNDesc{Float32}(CUDA.GRU, 10, 5)
-  randinit(rnn)
-  x, h = cu(rand(10)), cu(rand(5))
-  @test collect(test_forward(rnn, x, h)[1]) ≈
-    collect(CUDA.forwardInference(rnn, x, h)[1])
-
-  rnn = RNNDesc{Float32}(CUDA.LSTM, 10, 5)
-  randinit(rnn)
-  x, h, c = cu(rand(10)), cu(rand(5)), cu(rand(5))
-  @test collect(test_forward(rnn, x, h, c)[1]) ≈
-    collect(CUDA.forwardInference(rnn, x, h, c)[1])
-  @test collect(test_forward(rnn, x, h, c)[2]) ≈
-    collect(CUDA.forwardInference(rnn, x, h, c)[2])
-
-end
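
Note: the rewritten tests drop the hand-written test_forward reference kernels and instead treat the CPU Flux cells as the reference, checking the CUDA path end to end: forward values, input gradients, and every weight gradient, with the cell-state check guarded by isdefined since only LSTM cells carry a c field. The same pattern extends to other input shapes, e.g. a single unbatched sample (a sketch in the same tracker-era API, reusing rnn/curnn from the testset):

    xs = param(rand(10))  # one sample, no batch dimension
    @test rnn(xs).data ≈ collect(curnn(cu(xs)).data)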