Flux.jl/test/cuda/cudnn.jl

using Flux, CuArrays, Base.Test
using Flux.CUDA
using Flux.CUDA: RNNDesc
using CUDAnative
info("Testing Flux/CUDNN")
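
# Fill the descriptor's weight and bias arrays with Gaussian noise so the
# forward passes below are exercised on non-trivial values.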
function randinit(r::RNNDesc{T}) where T
  for w in r.weights
    copy!(w, randn(T, size(w)))
  end
  for w in r.biases
    copy!(w, randn(T, size(w)))
  end
end
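
# CPU reference implementation of the forward pass for the RNN modes tested
# below, used to check CUDA.forwardInference.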
function test_forward(rnn::RNNDesc, x, h, c = nothing)
  if rnn.mode == CUDA.RNN_RELU
    Wx, Wh = rnn.weights
    b, = rnn.biases
    h = relu.(Wx'*x .+ Wh'*h .+ b)
    return h, h
  elseif rnn.mode == CUDA.GRU
    Rx, Ux, Cx, Rh, Uh, Ch = rnn.weights
    bR, bU, bC = rnn.biases
    r = σ.(Rx'*x .+ Rh'*h .+ bR)
    z = σ.(Ux'*x .+ Uh'*h .+ bU)
    h̃ = CUDAnative.tanh.(Cx'*x .+ r .* Ch'*h .+ bC)
    h = (1 .- z) .* h̃ .+ z .* h
    return h, h
  end
end
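
# Compare the CPU reference against the CuDNN forward pass for both cell types.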
@testset "CUDNN" begin

rnn = RNNDesc{Float32}(CUDA.RNN_RELU, 10, 5)
randinit(rnn)
x, h = cu(rand(10)), cu(rand(5))
@test collect(test_forward(rnn, x, h)[1]) ≈ collect(CUDA.forwardInference(rnn, x, h)[1])

rnn = RNNDesc{Float32}(CUDA.GRU, 10, 5)
randinit(rnn)
x, h = cu(rand(10)), cu(rand(5))
@test collect(test_forward(rnn, x, h)[1]) ≈ collect(CUDA.forwardInference(rnn, x, h)[1])

end