Modifying tests in curnn.jl

This commit is contained in:
thebhatman 2019-06-13 18:45:37 +05:30
parent 80c680c598
commit ce6a1bf84f

View File

@ -1,46 +1,46 @@
using Flux, CuArrays, Test using Flux, CuArrays, Test
# @testset "RNN" begin @testset "RNN" begin
# @testset for R in [RNN, GRU, LSTM] @testset for R in [RNN, GRU, LSTM]
# rnn = R(10, 5) rnn = R(10, 5)
# curnn = mapleaves(gpu, rnn) curnn = mapleaves(gpu, rnn)
# @testset for batch_size in (1, 5) @testset for batch_size in (1, 5)
# Flux.reset!(rnn) Flux.reset!(rnn)
# Flux.reset!(curnn) Flux.reset!(curnn)
# x = batch_size == 1 ? x = batch_size == 1 ?
# param(rand(10)) : param(rand(10)) :
# param(rand(10,batch_size)) param(rand(10,batch_size))
# cux = gpu(x) cux = gpu(x)
# y = (rnn(x); rnn(x)) y = (rnn(x); rnn(x))
# cuy = (curnn(cux); curnn(cux)) cuy = (curnn(cux); curnn(cux))
#
# @test y.data ≈ collect(cuy.data) @test y collect(cuy)
# @test haskey(Flux.CUDA.descs, curnn.cell) @test haskey(Flux.CUDA.descs, curnn.cell)
#
# Δ = randn(size(y)) #Δ = randn(size(y))
#
# Flux.back!(y, Δ) #Flux.back!(y, Δ)
# Flux.back!(cuy, gpu(Δ)) #Flux.back!(cuy, gpu(Δ))
#
# @test x.grad ≈ collect(cux.grad) @test x collect(cux)
# @test rnn.cell.Wi.grad ≈ collect(curnn.cell.Wi.grad) @test rnn.cell.Wi collect(curnn.cell.Wi)
# @test rnn.cell.Wh.grad ≈ collect(curnn.cell.Wh.grad) @test rnn.cell.Wh collect(curnn.cell.Wh)
# @test rnn.cell.b.grad ≈ collect(curnn.cell.b.grad) @test rnn.cell.b collect(curnn.cell.b)
# @test rnn.cell.h.grad ≈ collect(curnn.cell.h.grad) @test rnn.cell.h collect(curnn.cell.h)
# if isdefined(rnn.cell, :c) if isdefined(rnn.cell, :c)
# @test rnn.cell.c.grad ≈ collect(curnn.cell.c.grad) @test rnn.cell.c collect(curnn.cell.c)
# end end
#
# Flux.reset!(rnn) Flux.reset!(rnn)
# Flux.reset!(curnn) Flux.reset!(curnn)
# ohx = batch_size == 1 ? ohx = batch_size == 1 ?
# Flux.onehot(rand(1:10), 1:10) : Flux.onehot(rand(1:10), 1:10) :
# Flux.onehotbatch(rand(1:10, batch_size), 1:10) Flux.onehotbatch(rand(1:10, batch_size), 1:10)
# cuohx = gpu(ohx) cuohx = gpu(ohx)
# y = (rnn(ohx); rnn(ohx)) y = (rnn(ohx); rnn(ohx))
# cuy = (curnn(cuohx); curnn(cuohx)) cuy = (curnn(cuohx); curnn(cuohx))
#
# @test y.data ≈ collect(cuy.data) @test y collect(cuy)
# end end
# end end
# end end