# Test Flux's cuDNN RNN bindings by comparing CUDA.forwardInference against a
# plain-Julia forward pass for each supported mode (RNN_RELU, GRU, LSTM).
using Flux, CuArrays, Base.Test
using Flux.CUDA
using Flux.CUDA: RNNDesc
using CUDAnative

info("Testing Flux/CUDNN")

# Fill an RNNDesc's weight matrices and bias with Gaussian random values.
function randinit(r::RNNDesc{T}) where T
  for w in (r.weights..., r.bias)
    copy!(w, randn(T, size(w)))
  end
end

# GPU-friendly tanh from CUDAnative, for broadcasting over CuArrays.
const cutanh = CUDAnative.tanh

# Slice gate `n` out of a stacked gate vector or bias: the gates are packed
# end-to-end, `rnn.hidden` entries each.
gate(rnn, x, n) = x[(1:rnn.hidden) + rnn.hidden*(n-1)]
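# For example, with rnn.hidden == 5 (sizes here purely illustrative):
#   gate(rnn, g, 1) == g[1:5],  gate(rnn, g, 2) == g[6:10],  gate(rnn, g, 3) == g[11:15]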

# Reference single-step forward pass, mirroring what CUDA.forwardInference
# computes for each RNN mode. Returns (output, hidden[, cell]).
function test_forward(rnn::RNNDesc, x, h, c = nothing)
  if rnn.mode == CUDA.RNN_RELU
    Wx, Wh = rnn.weights
    b = rnn.bias
    h′ = relu.(Wx'*x .+ Wh'*h .+ b)
    return h′, h′
  elseif rnn.mode == CUDA.GRU
    Wx, Wh = rnn.weights
    b = rnn.bias
    gx, gh = Wx'*x, Wh'*h
    r = σ.(gate(rnn, gx, 1) .+ gate(rnn, gh, 1) .+ gate(rnn, b, 1))            # reset gate
    z = σ.(gate(rnn, gx, 2) .+ gate(rnn, gh, 2) .+ gate(rnn, b, 2))            # update gate
    h̃ = cutanh.(gate(rnn, gx, 3) .+ r .* gate(rnn, gh, 3) .+ gate(rnn, b, 3))  # candidate state
    h′ = (1 .- z).*h̃ .+ z.*h
    return h′, h′
  elseif rnn.mode == CUDA.LSTM
    Wx, Wh = rnn.weights
    b = rnn.bias
    g = Wx'*x .+ Wh'*h .+ b
    input  = σ.(gate(rnn, g, 1))
    forget = σ.(gate(rnn, g, 2))
    cell   = cutanh.(gate(rnn, g, 3))
    output = σ.(gate(rnn, g, 4))
    c = forget .* c .+ input .* cell
    h = output .* cutanh.(c)
    return (h, h, c)
  end
end
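# The GRU branch above, written out as equations. This is only a transcription
# of the code; the subscripted W/b names are illustrative, each standing for
# the corresponding gate slice of Wx', Wh' and b:
#   r  = σ.(Wxr*x .+ Whr*h .+ br)
#   z  = σ.(Wxz*x .+ Whz*h .+ bz)
#   h̃  = tanh.(Wxh*x .+ r .* (Whh*h) .+ bh)
#   h′ = (1 .- z).*h̃ .+ z.*h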

@testset "CUDNN" begin

# Vanilla RNN with relu activation: input size 10, hidden size 5.
# `collect` copies the CuArrays back to the host so `≈` compares on the CPU.
rnn = RNNDesc{Float32}(CUDA.RNN_RELU, 10, 5)
randinit(rnn)
x, h = cu(rand(10)), cu(rand(5))
@test collect(test_forward(rnn, x, h)[1]) ≈
      collect(CUDA.forwardInference(rnn, x, h)[1])

# GRU, same sizes.
rnn = RNNDesc{Float32}(CUDA.GRU, 10, 5)
randinit(rnn)
x, h = cu(rand(10)), cu(rand(5))
@test collect(test_forward(rnn, x, h)[1]) ≈
      collect(CUDA.forwardInference(rnn, x, h)[1])

# LSTM needs an initial cell state as well; compare both returned arrays.
rnn = RNNDesc{Float32}(CUDA.LSTM, 10, 5)
randinit(rnn)
x, h, c = cu(rand(10)), cu(rand(5)), cu(rand(5))
@test collect(test_forward(rnn, x, h, c)[1]) ≈
      collect(CUDA.forwardInference(rnn, x, h, c)[1])
@test collect(test_forward(rnn, x, h, c)[2]) ≈
      collect(CUDA.forwardInference(rnn, x, h, c)[2])

end
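# To run this file on its own (assuming a CUDA-capable GPU with CuArrays and
# CUDAnative installed), include it from the Julia REPL, e.g.:
#   julia> include("cudnn.jl")   # path is illustrative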