gru forward
This commit is contained in:
parent
b1bb05403c
commit
4bfb603da6
|
@ -173,8 +173,8 @@ function (m::GRUCell)(h, x)
|
|||
z = m.update(x′)
|
||||
r = m.reset(x′)
|
||||
h̃ = m.candidate(combine(r.*h, x))
|
||||
h = (1.-z).*h .+ z.*h̃
|
||||
return h, h
|
||||
h′ = (1.-z).*h̃ .+ z.*h
|
||||
return h′, h′
|
||||
end
|
||||
|
||||
hidden(m::GRUCell) = m.h
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
using Flux, CuArrays, Base.Test
|
||||
using Flux.CUDA
|
||||
using Flux.CUDA: RNNDesc, RNN_TANH, RNN_RELU
|
||||
using Flux.CUDA: RNNDesc
|
||||
using CUDAnative
|
||||
|
||||
info("Testing Flux/CUDNN")
|
||||
|
||||
|
@ -14,17 +15,30 @@ function randinit(r::RNNDesc{T}) where T
|
|||
end
|
||||
|
||||
function test_forward(rnn::RNNDesc, x, h, c = nothing)
|
||||
if rnn.mode == RNN_RELU
|
||||
if rnn.mode == CUDA.RNN_RELU
|
||||
Wx, Wh = rnn.weights
|
||||
b, = rnn.biases
|
||||
h′ = relu.(Wx'*x .+ Wh'*h .+ b)
|
||||
return h′, h′
|
||||
elseif rnn.mode == CUDA.GRU
|
||||
Rx, Ux, Cx, Rh, Uh, Ch = rnn.weights
|
||||
bR, bU, bC = rnn.biases
|
||||
r = σ.(Rx'*x .+ Rh'*h .+ bR)
|
||||
z = σ.(Ux'*x .+ Uh'*h .+ bU)
|
||||
h̃ = CUDAnative.tanh.(Cx'*x .+ r .* Ch'*h .+ bC)
|
||||
h′ = (1.-z).*h̃ .+ z.*h
|
||||
return h′, h′
|
||||
end
|
||||
end
|
||||
|
||||
@testset "CUDNN" begin
|
||||
|
||||
rnn = RNNDesc{Float32}(RNN_RELU, 10, 5)
|
||||
rnn = RNNDesc{Float32}(CUDA.RNN_RELU, 10, 5)
|
||||
randinit(rnn)
|
||||
x, h = cu(rand(10)), cu(rand(5))
|
||||
@test collect(test_forward(rnn, x, h)[1]) ≈ collect(CUDA.forwardInference(rnn, x, h)[1])
|
||||
|
||||
rnn = RNNDesc{Float32}(CUDA.GRU, 10, 5)
|
||||
randinit(rnn)
|
||||
x, h = cu(rand(10)), cu(rand(5))
|
||||
@test collect(test_forward(rnn, x, h)[1]) ≈ collect(CUDA.forwardInference(rnn, x, h)[1])
|
||||
|
|
Loading…
Reference in New Issue