64 lines
1.5 KiB
Julia
64 lines
1.5 KiB
Julia
using Flux, CuArrays, Test
|
|
using Flux: pullback
|
|
|
|
# Explicit (`gradient(f, m)`) and implicit (`gradient(f, params(m))`) gradients of a
# GPU recurrent layer must agree. `Flux.reset!` is called between the two gradient
# computations so both forward passes start from the same layer state.
@testset "explicit vs implicit gradient: $R" for R in [RNN, GRU, LSTM]
    m = R(10, 5) |> gpu
    x = gpu(rand(10))

    # The explicit gradient is accessed through `[]` dereferences (Ref-wrapped
    # structural gradient). NOTE(review): this matches the Zygote version this
    # file targets — confirm if the AD backend is upgraded.
    (m̄,) = gradient(model -> sum(model(x)), m)
    Flux.reset!(m)
    θ = gradient(() -> sum(m(x)), params(m))

    # Use `≈` rather than `==`: the two gradient paths are not guaranteed to do
    # bit-identical floating-point work, and exact equality makes the test brittle.
    @test collect(m̄[].cell[].Wi) ≈ collect(θ[m.cell.Wi])
end
|
|
|
|
# CPU/GPU parity for recurrent layers. A CPU model and its `fmap(gpu, ...)` copy
# must agree on:
#   - forward outputs;
#   - input, parameter and state gradients from `pullback`;
#   - the stateful one-hot path.
# Each check runs for a single sample (`batch_size == 1`) and for a batch.
@testset "RNN" begin
    @testset for R in [RNN, GRU, LSTM], batch_size in (1, 5)
        rnn = R(10, 5)
        curnn = fmap(gpu, rnn)

        Flux.reset!(rnn)
        Flux.reset!(curnn)
        x = batch_size == 1 ? rand(10) : rand(10, batch_size)
        cux = gpu(x)

        # Differentiate w.r.t. both the model and its input.
        y, back = pullback((r, input) -> r(input), rnn, x)
        cuy, cuback = pullback((r, input) -> r(input), curnn, cux)

        @test y ≈ collect(cuy)
        # The GPU cell should have an entry in `Flux.CUDA.descs` after the forward
        # pass (presumably a cuDNN descriptor cache — NOTE(review): confirm).
        @test haskey(Flux.CUDA.descs, curnn.cell)

        ȳ = randn(size(y))
        m̄, x̄ = back(ȳ)
        cum̄, cux̄ = cuback(gpu(ȳ))

        @test x̄ ≈ collect(cux̄)
        @test m̄[].cell[].Wi ≈ collect(cum̄[].cell[].Wi)
        @test m̄[].cell[].Wh ≈ collect(cum̄[].cell[].Wh)
        @test m̄[].cell[].b ≈ collect(cum̄[].cell[].b)

        # Some layers carry a tuple state (presumably LSTM's (h, c)); compare each
        # component. Otherwise the state gradient is a single array.
        if m̄[].state isa Tuple
            for (s̄, cus̄) in zip(m̄[].state, cum̄[].state)
                @test s̄ ≈ collect(cus̄)
            end
        else
            @test m̄[].state ≈ collect(cum̄[].state)
        end

        # One-hot inputs: call each model twice after a reset, so the second call
        # runs on the state left behind by the first.
        Flux.reset!(rnn)
        Flux.reset!(curnn)
        ohx = batch_size == 1 ?
            Flux.onehot(rand(1:10), 1:10) :
            Flux.onehotbatch(rand(1:10, batch_size), 1:10)
        cuohx = gpu(ohx)
        y = (rnn(ohx); rnn(ohx))
        cuy = (curnn(cuohx); curnn(cuohx))

        @test y ≈ collect(cuy)
    end
end
|