diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 6faa8c95..bd0c2198 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -56,11 +56,13 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray alpha = T(1), beta = T(0), eps = T(1e-5), training = true) where T<:Union{Float32, Float64} dims = _wsize(x) - if(eps < BATCHNORM_MIN_EPS) warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS) eps = BATCHNORM_MIN_EPS end + xd = TensorDesc(x) + yd = TensorDesc(y) + gd = TensorDesc(T, (1,1,length(g),1)) if(training) @@ -82,9 +84,9 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray Cdouble, Ptr{T}, Ptr{T}), libcudnn_handle[], BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), - TensorDesc(x), x, - TensorDesc(y), y, - TensorDesc(g), g, b, + xd, x, + yd, y, + gd, g, b, momentum, running_mean, running_var, eps, mean, ivar) @@ -93,9 +95,8 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray cache.ivar = ivar end else - @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t, - (cudnnHandle_t,cudnnBatchNormMode_t, + (Ptr{cudnnHandle_t},cudnnBatchNormMode_t, Ptr{T}, Ptr{T}, Ptr{Void}, Ptr{T}, Ptr{Void}, Ptr{T}, @@ -104,9 +105,9 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray Cdouble), libcudnn_handle[], BATCHNORM_SPATIAL, Ref(T(alpha)), Ref(T(beta)), - TensorDesc(x), x, - TensorDesc(y), y, - TensorDesc(g), g, b, + xd, x, + yd, y, + gd, g, b, running_mean, running_var, eps) end