rewrite tests
This commit is contained in:
parent
d592f4e327
commit
bc452fcd81
|
@ -70,6 +70,9 @@ function rnnParamSize(T, r, input)
|
|||
return Int(size[])÷sizeof(T)
|
||||
end
|
||||
|
||||
ngates(mode) = [1, 1, 4, 3][mode+1]
|
||||
ngates(r::RNNDesc) = ngates(r.mode)
|
||||
|
||||
function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
||||
d = [C_NULL]
|
||||
@check ccall((:cudnnCreateRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Ptr{Void}},),d)
|
||||
|
@ -81,10 +84,9 @@ function RNNDesc{T}(mode::Int, input::Int, hidden::Int; layers = 1) where T
|
|||
@check ccall((:cudnnSetRNNDescriptor_v6,libcudnn), cudnnStatus_t, (Ptr{Void},Ptr{Void},Cint,Cint,Ptr{Void},Cint,Cint,Cint,Cint,Cint),
|
||||
libcudnn_handle[],d[],hidden,layers,dropoutDesc,inputMode,direction,mode,algo,cudnnDataType(T))
|
||||
|
||||
w = cuzeros(T, rnnParamSize(T, d[], 10))
|
||||
ngates = [1, 1, 4, 3][mode+1]
|
||||
w = cuzeros(T, rnnParamSize(T, d[], input))
|
||||
# TODO: avoid reserve allocation here
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates)..., CuVector{UInt8}(1), d[])
|
||||
rd = RNNDesc{T}(mode, input, hidden, w, params(w, input, hidden, ngates(mode))..., CuVector{UInt8}(1), d[])
|
||||
finalizer(rd, x ->
|
||||
@check ccall((:cudnnDestroyRNNDescriptor,libcudnn),cudnnStatus_t,(Ptr{Void},),x))
|
||||
return rd
|
||||
|
@ -165,6 +167,7 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; t
|
|||
seqLength = 1
|
||||
xdesc = xDesc(x)
|
||||
y = x isa AbstractVector ? similar(x, rnn.hidden) : similar(x, rnn.hidden, size(x, 2))
|
||||
ho = similar(h)
|
||||
ydesc = xDesc(y)
|
||||
workspace = getworkspace(rnn, seqLength, xdesc)
|
||||
reserve = train ? getreserve(rnn, seqLength, xdesc) : rnn.reserve
|
||||
|
@ -175,10 +178,10 @@ function forward(rnn::RNNDesc{T}, x::CuArray{T}, h_::CuArray{T}, c_ = nothing; t
|
|||
hDesc(c)...,
|
||||
FilterDesc(T, (1, 1, length(rnn.params))), rnn.params,
|
||||
ydesc, y,
|
||||
C_NULL, C_NULL, # hout
|
||||
hDesc(ho)...,
|
||||
hDesc(co)...,
|
||||
workspace, reserve, train = train)
|
||||
return c == nothing ? (y, y) : (y, y, co)
|
||||
return c == nothing ? (y, ho) : (y, ho, co)
|
||||
end
|
||||
|
||||
function cudnnRNNBackwardData(rnn::RNNDesc{T}, seqlen, yd, y, dyd, dy, dhod, dho, dcod, dco,
|
||||
|
@ -229,7 +232,7 @@ function backwardWeights(rnn::RNNDesc{T}, x, h, y) where T
|
|||
xDesc(x), x, hDesc(h)..., xDesc(y), y,
|
||||
FilterDesc(T, (1, 1, length(dw))), dw,
|
||||
workspace[], rnn.reserve)
|
||||
return params(dw, rnn.input, rnn.hidden)
|
||||
return params(dw, rnn.input, rnn.hidden, ngates(rnn))
|
||||
end
|
||||
|
||||
# Interface
|
||||
|
@ -283,6 +286,7 @@ end
|
|||
|
||||
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast
|
||||
|
||||
# TODO: fix reserve space usage
|
||||
struct RNNCall{R}
|
||||
rnn::R
|
||||
end
|
||||
|
|
|
@ -5,65 +5,30 @@ using CUDAnative
|
|||
|
||||
info("Testing Flux/CUDNN")
|
||||
|
||||
function randinit(r::RNNDesc{T}) where T
|
||||
for w in (r.weights..., r.bias)
|
||||
copy!(w, randn(T, size(w)))
|
||||
@testset "RNN" begin
|
||||
@testset for R in [RNN, GRU, LSTM]
|
||||
x = param(rand(10,5))
|
||||
cux = cu(x)
|
||||
rnn = R(10, 5)
|
||||
curnn = mapleaves(cu, rnn)
|
||||
y = rnn(x)
|
||||
cuy = curnn(cux)
|
||||
|
||||
@test y.data ≈ collect(cuy.data)
|
||||
@test haskey(Flux.CUDA.descs, curnn.cell)
|
||||
|
||||
Δ = randn(size(y))
|
||||
|
||||
Flux.back!(y, Δ)
|
||||
Flux.back!(cuy, cu(Δ))
|
||||
|
||||
@test x.grad ≈ collect(cux.grad)
|
||||
@test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad)
|
||||
@test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad)
|
||||
@test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad)
|
||||
@test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad)
|
||||
if isdefined(rnn.cell, :c)
|
||||
@test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
const cutanh = CUDAnative.tanh
|
||||
|
||||
gate(rnn, x, n) = x[(1:rnn.hidden) + rnn.hidden*(n-1)]
|
||||
|
||||
function test_forward(rnn::RNNDesc, x, h, c = nothing)
|
||||
if rnn.mode == CUDA.RNN_RELU
|
||||
Wx, Wh = rnn.weights
|
||||
b = rnn.bias
|
||||
h′ = relu.(Wx'*x .+ Wh'*h .+ b)
|
||||
return h′, h′
|
||||
elseif rnn.mode == CUDA.GRU
|
||||
Wx, Wh = rnn.weights
|
||||
b = rnn.bias
|
||||
gx, gh = Wx'*x, Wh'*h
|
||||
r = σ.(gate(rnn, gx, 1) .+ gate(rnn, gh, 1) .+ gate(rnn, b, 1))
|
||||
z = σ.(gate(rnn, gx, 2) .+ gate(rnn, gh, 2) .+ gate(rnn, b, 2))
|
||||
h̃ = cutanh.(gate(rnn, gx, 3) .+ r .* gate(rnn, gh, 3) .+ gate(rnn, b, 3))
|
||||
h′ = (1.-z).*h̃ .+ z.*h
|
||||
return h′, h′
|
||||
elseif rnn.mode == CUDA.LSTM
|
||||
Wx, Wh = rnn.weights
|
||||
b = rnn.bias
|
||||
g = Wx'*x .+ Wh'*h .+ b
|
||||
input = σ.(gate(rnn, g, 1))
|
||||
forget = σ.(gate(rnn, g, 2))
|
||||
cell = cutanh.(gate(rnn, g, 3))
|
||||
output = σ.(gate(rnn, g, 4))
|
||||
c = forget .* c .+ input .* cell
|
||||
h = output .* cutanh.(c)
|
||||
return (h, h, c)
|
||||
end
|
||||
end
|
||||
|
||||
@testset "CUDNN" begin
|
||||
|
||||
rnn = RNNDesc{Float32}(CUDA.RNN_RELU, 10, 5)
|
||||
randinit(rnn)
|
||||
x, h = cu(rand(10)), cu(rand(5))
|
||||
@test collect(test_forward(rnn, x, h)[1]) ≈
|
||||
collect(CUDA.forwardInference(rnn, x, h)[1])
|
||||
|
||||
rnn = RNNDesc{Float32}(CUDA.GRU, 10, 5)
|
||||
randinit(rnn)
|
||||
x, h = cu(rand(10)), cu(rand(5))
|
||||
@test collect(test_forward(rnn, x, h)[1]) ≈
|
||||
collect(CUDA.forwardInference(rnn, x, h)[1])
|
||||
|
||||
rnn = RNNDesc{Float32}(CUDA.LSTM, 10, 5)
|
||||
randinit(rnn)
|
||||
x, h, c = cu(rand(10)), cu(rand(5)), cu(rand(5))
|
||||
@test collect(test_forward(rnn, x, h, c)[1]) ≈
|
||||
collect(CUDA.forwardInference(rnn, x, h, c)[1])
|
||||
@test collect(test_forward(rnn, x, h, c)[2]) ≈
|
||||
collect(CUDA.forwardInference(rnn, x, h, c)[2])
|
||||
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue