Fix Tensor Descriptor Bug
This commit is contained in:
parent
c6dcf079ce
commit
af5ab7f9ef
@ -56,11 +56,13 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
alpha = T(1), beta = T(0),
|
alpha = T(1), beta = T(0),
|
||||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||||
dims = _wsize(x)
|
dims = _wsize(x)
|
||||||
|
|
||||||
if(eps < BATCHNORM_MIN_EPS)
|
if(eps < BATCHNORM_MIN_EPS)
|
||||||
warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||||
eps = BATCHNORM_MIN_EPS
|
eps = BATCHNORM_MIN_EPS
|
||||||
end
|
end
|
||||||
|
xd = TensorDesc(x)
|
||||||
|
yd = TensorDesc(y)
|
||||||
|
gd = TensorDesc(T, (1,1,length(g),1))
|
||||||
|
|
||||||
if(training)
|
if(training)
|
||||||
|
|
||||||
@ -82,9 +84,9 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
Cdouble, Ptr{T}, Ptr{T}),
|
Cdouble, Ptr{T}, Ptr{T}),
|
||||||
libcudnn_handle[], BATCHNORM_SPATIAL,
|
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||||
Ref(T(alpha)), Ref(T(beta)),
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
TensorDesc(x), x,
|
xd, x,
|
||||||
TensorDesc(y), y,
|
yd, y,
|
||||||
TensorDesc(g), g, b,
|
gd, g, b,
|
||||||
momentum, running_mean, running_var,
|
momentum, running_mean, running_var,
|
||||||
eps, mean, ivar)
|
eps, mean, ivar)
|
||||||
|
|
||||||
@ -93,9 +95,8 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
cache.ivar = ivar
|
cache.ivar = ivar
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
|
|
||||||
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
@check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
|
||||||
(cudnnHandle_t,cudnnBatchNormMode_t,
|
(Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
|
||||||
Ptr{T}, Ptr{T},
|
Ptr{T}, Ptr{T},
|
||||||
Ptr{Void}, Ptr{T},
|
Ptr{Void}, Ptr{T},
|
||||||
Ptr{Void}, Ptr{T},
|
Ptr{Void}, Ptr{T},
|
||||||
@ -104,9 +105,9 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
Cdouble),
|
Cdouble),
|
||||||
libcudnn_handle[], BATCHNORM_SPATIAL,
|
libcudnn_handle[], BATCHNORM_SPATIAL,
|
||||||
Ref(T(alpha)), Ref(T(beta)),
|
Ref(T(alpha)), Ref(T(beta)),
|
||||||
TensorDesc(x), x,
|
xd, x,
|
||||||
TensorDesc(y), y,
|
yd, y,
|
||||||
TensorDesc(g), g, b,
|
gd, g, b,
|
||||||
running_mean, running_var,
|
running_mean, running_var,
|
||||||
eps)
|
eps)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user