Fix tests
This commit is contained in:
parent
c4f87ff15c
commit
7e7a501efd
|
@ -2,8 +2,6 @@ using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
|||
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
|
||||
import ..Flux: data
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
mutable struct DropoutDesc
|
||||
ptr::Ptr{Nothing}
|
||||
states::CuVector{UInt8}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
|
||||
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
const RNN_RELU = 0 # Stock RNN with ReLu activation
|
||||
const RNN_TANH = 1 # Stock RNN with tanh activation
|
||||
const LSTM = 2 # LSTM with no peephole connections
|
||||
|
|
|
@ -55,11 +55,11 @@ end
|
|||
# .1 * 4 + 0 = .4
|
||||
@test m.μ ≈ reshape([0.3, 0.4], 2, 1)
|
||||
|
||||
# julia> .1 .* std(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
# julia> .1 .* var(x, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
# 2×1 Array{Float64,2}:
|
||||
# 1.14495
|
||||
# 1.14495
|
||||
@test m.σ² ≈ .1 .* std(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
# 1.3
|
||||
# 1.3
|
||||
@test m.σ² ≈ .1 .* var(x.data, dims = 2, corrected=false) .* (3 / 2).+ .9 .* [1., 1.]
|
||||
|
||||
testmode!(m)
|
||||
@test !m.active
|
||||
|
|
Loading…
Reference in New Issue