gru forward

This commit is contained in:
Mike J Innes 2018-01-31 13:46:55 +00:00
parent b1bb05403c
commit 4bfb603da6
2 changed files with 19 additions and 5 deletions

View File

@ -173,8 +173,8 @@ function (m::GRUCell)(h, x)
z = m.update(x)
r = m.reset(x)
= m.candidate(combine(r.*h, x))
h = (1.-z).*h .+ z.*h̃
return h, h
h = (1.-z).*h̃ .+ z.*h
return h, h
end
hidden(m::GRUCell) = m.h

View File

@ -1,6 +1,7 @@
using Flux, CuArrays, Base.Test
using Flux.CUDA
using Flux.CUDA: RNNDesc, RNN_TANH, RNN_RELU
using Flux.CUDA: RNNDesc
using CUDAnative
info("Testing Flux/CUDNN")
@ -14,17 +15,30 @@ function randinit(r::RNNDesc{T}) where T
end
function test_forward(rnn::RNNDesc, x, h, c = nothing)
if rnn.mode == RNN_RELU
if rnn.mode == CUDA.RNN_RELU
Wx, Wh = rnn.weights
b, = rnn.biases
h = relu.(Wx'*x .+ Wh'*h .+ b)
return h, h
elseif rnn.mode == CUDA.GRU
Rx, Ux, Cx, Rh, Uh, Ch = rnn.weights
bR, bU, bC = rnn.biases
r = σ.(Rx'*x .+ Rh'*h .+ bR)
z = σ.(Ux'*x .+ Uh'*h .+ bU)
= CUDAnative.tanh.(Cx'*x .+ r .* Ch'*h .+ bC)
h = (1.-z).* .+ z.*h
return h, h
end
end
@testset "CUDNN" begin
rnn = RNNDesc{Float32}(RNN_RELU, 10, 5)
rnn = RNNDesc{Float32}(CUDA.RNN_RELU, 10, 5)
randinit(rnn)
x, h = cu(rand(10)), cu(rand(5))
@test collect(test_forward(rnn, x, h)[1]) collect(CUDA.forwardInference(rnn, x, h)[1])
rnn = RNNDesc{Float32}(CUDA.GRU, 10, 5)
randinit(rnn)
x, h = cu(rand(10)), cu(rand(5))
@test collect(test_forward(rnn, x, h)[1]) collect(CUDA.forwardInference(rnn, x, h)[1])