LSTM
This commit is contained in:
parent
4bfb603da6
commit
8ad837bb70
@@ -14,6 +14,8 @@ function randinit(r::RNNDesc{T}) where T
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
const cutanh = CUDAnative.tanh
|
||||||
|
|
||||||
function test_forward(rnn::RNNDesc, x, h, c = nothing)
|
function test_forward(rnn::RNNDesc, x, h, c = nothing)
|
||||||
if rnn.mode == CUDA.RNN_RELU
|
if rnn.mode == CUDA.RNN_RELU
|
||||||
Wx, Wh = rnn.weights
|
Wx, Wh = rnn.weights
|
||||||
@@ -25,9 +27,19 @@ function test_forward(rnn::RNNDesc, x, h, c = nothing)
|
|||||||
bR, bU, bC = rnn.biases
|
bR, bU, bC = rnn.biases
|
||||||
r = σ.(Rx'*x .+ Rh'*h .+ bR)
|
r = σ.(Rx'*x .+ Rh'*h .+ bR)
|
||||||
z = σ.(Ux'*x .+ Uh'*h .+ bU)
|
z = σ.(Ux'*x .+ Uh'*h .+ bU)
|
||||||
h̃ = CUDAnative.tanh.(Cx'*x .+ r .* Ch'*h .+ bC)
|
h̃ = cutanh.(Cx'*x .+ r .* Ch'*h .+ bC)
|
||||||
h′ = (1.-z).*h̃ .+ z.*h
|
h′ = (1.-z).*h̃ .+ z.*h
|
||||||
return h′, h′
|
return h′, h′
|
||||||
|
elseif rnn.mode == CUDA.LSTM
|
||||||
|
Ix, Fx, Cx, Ox, Ih, Fh, Ch, Oh = rnn.weights
|
||||||
|
bI, bF, bC, bO = rnn.biases
|
||||||
|
input = σ.(Ix'*x .+ Ih'*h .+ bI)
|
||||||
|
forget = σ.(Fx'*x .+ Fh'*h .+ bF)
|
||||||
|
cell = cutanh.(Cx'*x .+ Ch'*h .+ bC)
|
||||||
|
output = σ.(Ox'*x .+ Oh'*h .+ bO)
|
||||||
|
c = forget .* c .+ input .* cell
|
||||||
|
h = output .* cutanh.(c)
|
||||||
|
return (h, h, c)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@@ -43,4 +55,10 @@ randinit(rnn)
|
|||||||
x, h = cu(rand(10)), cu(rand(5))
|
x, h = cu(rand(10)), cu(rand(5))
|
||||||
@test collect(test_forward(rnn, x, h)[1]) ≈ collect(CUDA.forwardInference(rnn, x, h)[1])
|
@test collect(test_forward(rnn, x, h)[1]) ≈ collect(CUDA.forwardInference(rnn, x, h)[1])
|
||||||
|
|
||||||
|
rnn = RNNDesc{Float32}(CUDA.LSTM, 10, 5)
|
||||||
|
randinit(rnn)
|
||||||
|
x, h, c = cu(rand(10)), cu(rand(5)), cu(rand(5))
|
||||||
|
@test collect(test_forward(rnn, x, h, c)[1]) ≈ collect(CUDA.forwardInference(rnn, x, h, c)[1])
|
||||||
|
@test collect(test_forward(rnn, x, h, c)[2]) ≈ collect(CUDA.forwardInference(rnn, x, h, c)[2])
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user