diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index c8dc553a..132e105f 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -73,20 +73,20 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
     end

     @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
-                  (cudnnHandle_t,cudnnBatchNormMode_t,
-                  Ptr{T}, Ptr{T},
-                  Ptr{Void}, Ptr{T},
-                  Ptr{Void}, Ptr{T},
-                  Ptr{Void}, Ptr{T}, Ptr{T},
-                  Cdouble, Ptr{T}, Ptr{T},
-                  Cdouble, Ptr{T}, Ptr{T}),
-                  libcudnn_handle[], BATCHNORM_SPATIAL,
-                  Ref(T(alpha)), Ref(T(beta)),
-                  xd, x,
-                  yd, y,
-                  gd, g, b,
-                  momentum, running_mean, running_var,
-                  eps, mean, ivar)
+                 (cudnnHandle_t,cudnnBatchNormMode_t,
+                 Ptr{T}, Ptr{T},
+                 Ptr{Void}, Ptr{T},
+                 Ptr{Void}, Ptr{T},
+                 Ptr{Void}, Ptr{T}, Ptr{T},
+                 Cdouble, Ptr{T}, Ptr{T},
+                 Cdouble, Ptr{T}, Ptr{T}),
+                 libcudnn_handle[], BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 xd, x,
+                 yd, y,
+                 gd, g, b,
+                 momentum, running_mean, running_var,
+                 eps, mean, ivar)

     if(cache !== nothing)
       cache.mean = mean
@@ -94,60 +94,78 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
     end
   else
     @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
-                  (Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
-                  Ptr{T}, Ptr{T},
-                  Ptr{Void}, Ptr{T},
-                  Ptr{Void}, Ptr{T},
-                  Ptr{Void}, Ptr{T}, Ptr{T},
-                  Ptr{T}, Ptr{T},
-                  Cdouble),
-                  libcudnn_handle[], BATCHNORM_SPATIAL,
-                  Ref(T(alpha)), Ref(T(beta)),
-                  xd, x,
-                  yd, y,
-                  gd, g, b,
-                  running_mean, running_var,
-                  eps)
+                 (Ptr{cudnnHandle_t},cudnnBatchNormMode_t,
+                 Ptr{T}, Ptr{T},
+                 Ptr{Void}, Ptr{T},
+                 Ptr{Void}, Ptr{T},
+                 Ptr{Void}, Ptr{T}, Ptr{T},
+                 Ptr{T}, Ptr{T},
+                 Cdouble),
+                 libcudnn_handle[], BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 xd, x,
+                 yd, y,
+                 gd, g, b,
+                 running_mean, running_var,
+                 eps)
   end
 end

-function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
-                          dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
-                          running_mean::CuArray{T}, running_var::CuArray{T},
-                          momentum; training = true,
-                          cache = nothing, eps = T(1e-5),
-                          alpha = T(1), beta = T(0),
-                          dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
-  if(training)
-
-    if cache !== nothing
-      mean, ivar = cache.mean, cache.ivar
-      cache_verbose && info("mean and ivar are fetched from the cache")
-    else
-      mean, ivar = C_NULL, C_NULL
-    end
-
-    @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
-                 (cudnnHandle_t,cudnnBatchNormMode_t,
-                 Ptr{T}, Ptr{T},
-                 Ptr{T}, Ptr{T},
-                 Ptr{Void}, Ptr{T},
-                 Ptr{Void}, Ptr{T},
-                 Ptr{Void}, Ptr{T},
-                 Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{T},
-                 Cdouble, Ptr{T}, Ptr{T}),
-                 libcudnn_handle[], BATCHNORM_SPATIAL,
-                 Ref(T(alpha)), Ref(T(beta)),
-                 Ref(T(dalpha)), Ref(T(dbeta)),
-                 TensorDesc(x), x,
-                 TensorDesc(dy), dy,
-                 TensorDesc(dx), dx,
-                 TensorDesc(g), g, dg, db,
-                 eps, mean, ivar)
-  else
-    ivar = 1 ./ sqrt.(running_var .+ eps)
-    dx = dy .* g .* ivar
-    dg = sum(dy .* (x .- running_mean) .* ivar, _reddims(dy))
-    db = sum(dy, _reddims(dy))
-  end
+function cudnnBNBackward(g, b, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T},
+                         running_var::CuArray{T}, momentum;
+                         training = true, cache = nothing, eps = T(1e-5),
+                         alpha = T(1), beta = T(0)) where T<:Union{Float32, Float64}
+  dx = similar(x)
+  cudnnBNBackward!(g.grad, data(g), b.grad, dx, x, dy, running_mean, running_var, T(momentum),
+                   training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
+  dx
+end
+
+function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
+                          dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
+                          running_mean::CuArray{T}, running_var::CuArray{T},
+                          momentum; training = true,
+                          cache = nothing, eps = T(1e-5),
+                          alpha = T(1), beta = T(0),
+                          dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
+  if(training)
+    xd = TensorDesc(x)
+    dyd = TensorDesc(dy)
+    dxd = TensorDesc(dx)
+    gd = TensorDesc(T, (1,1,length(g),1))
+    if cache !== nothing
+      mean, ivar = cache.mean, cache.ivar
+      info("mean and ivar are fetched from the cache")
+    else
+      mean, ivar = C_NULL, C_NULL
+    end
+
+    if(eps < BATCHNORM_MIN_EPS)
+      warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
+      eps = BATCHNORM_MIN_EPS
+    end
+
+    @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
+                 (cudnnHandle_t,cudnnBatchNormMode_t,
+                 Ptr{T}, Ptr{T},
+                 Ptr{T}, Ptr{T},
+                 Ptr{Void}, Ptr{T},
+                 Ptr{Void}, Ptr{T},
+                 Ptr{Void}, Ptr{T},
+                 Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{T},
+                 Cdouble, Ptr{T}, Ptr{T}),
+                 libcudnn_handle[], BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 Ref(T(dalpha)), Ref(T(dbeta)),
+                 xd, x,
+                 dyd, dy,
+                 dxd, dx,
+                 gd, g, dg, db,
+                 eps, mean, ivar)
+  else
+    ivar = 1 ./ sqrt.(reshape(running_var, (1, 1, length(running_var), 1)) .+ eps)
+    dx .= dy .* reshape(g, (1, 1, length(g), 1)) .* ivar
+    dg .= squeeze(sum(dy .* (x .- reshape(running_mean, (1, 1, length(running_mean), 1))) .* ivar, _reddims(dy)), (1,2,4))
+    db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
+  end
 end
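
For reference, the `else` branch of the new `cudnnBNBackward!` implements the standard inference-mode batchnorm gradients by hand, broadcasting the per-channel statistics over the spatial and batch dimensions. Below is a minimal CPU sketch of the same math; `bn_backward_inference` is a hypothetical helper, not part of this diff, written against the same Julia 0.6-era API the file uses (positional-dims `sum`, `squeeze`):

# Hypothetical CPU reference for the inference-mode branch above.
# x, dy are W×H×C×N arrays; g and the running statistics are length-C vectors.
function bn_backward_inference(g, x, dy, running_mean, running_var; eps = 1e-5)
  c = size(x, 3)
  bcast(v) = reshape(v, (1, 1, c, 1))            # lift a per-channel vector to 4-d
  ivar = 1 ./ sqrt.(bcast(running_var) .+ eps)   # 1 / sqrt(var + eps), per channel
  dx = dy .* bcast(g) .* ivar                    # gradient w.r.t. the input
  dg = squeeze(sum(dy .* (x .- bcast(running_mean)) .* ivar, (1, 2, 4)), (1, 2, 4))  # scale gradient
  db = squeeze(sum(dy, (1, 2, 4)), (1, 2, 4))    # shift gradient
  dx, dg, db
end

Since inference mode treats the running statistics as constants, no saved mean/ivar is needed there; the cache only feeds the training-mode cuDNN call.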