Fix tests

This commit is contained in:
Avik Pal 2018-09-11 16:32:14 +05:30
parent c4f87ff15c
commit 7e7a501efd
3 changed files with 6 additions and 6 deletions

View File

@ -2,8 +2,6 @@ using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
import ..Flux: data
using LinearAlgebra
mutable struct DropoutDesc
ptr::Ptr{Nothing}
states::CuVector{UInt8}

View File

@ -1,6 +1,8 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
using LinearAlgebra
const RNN_RELU = 0 # Stock RNN with ReLu activation
const RNN_TANH = 1 # Stock RNN with tanh activation
const LSTM = 2 # LSTM with no peephole connections

View File

@ -55,11 +55,11 @@ end
# .1 * 4 + 0 = .4
@test m.μ reshape([0.3, 0.4], 2, 1)
# julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# 2×1 Array{Float64,2}:
# 1.14495
# 1.14495
@test m.σ² .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
# 1.3
# 1.3
@test m.σ² .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
testmode!(m)
@test !m.active