From f12e367cab310a4966a7e7f22fc67f26547f2069 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 12 Jun 2018 18:26:09 +0530
Subject: [PATCH] Adding untested backward pass code

---
 src/cuda/cudnn.jl | 90 +++++++++++++++++++++++++++++++++++++----------
 1 file changed, 72 insertions(+), 18 deletions(-)

diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 1ce16d55..fd0dd7a6 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -47,15 +47,25 @@ end
 
 bncache() = bncache(nothing, nothing)
 
-(CuBN::CuBatchNorm)(x::CuArray{T}) where T<:Union{Float32, Float64} =
-  CuBN.λ.(cudnnBatchNormalizationForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, eps = CuBN.ϵ, training = CuBN.active))
+(CuBN::CuBatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} =
+  CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active))
 
-function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
-                                        running_mean::CuArray{T}, running_var::CuArray{T},
-                                        momentum::T; cache = nothing,
-                                        alpha = T(1), beta = T(0),
-                                        eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
-  y = similar(x)
+function cudnnBNForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
+                        running_mean::CuArray{T}, running_var::CuArray{T},
+                        momentum::T; cache = nothing,
+                        alpha = T(1), beta = T(0),
+                        eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+  y = similar(x)
+  cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
+                  alpha = alpha, beta = beta, eps = eps, training = training)
+  y
+end
+
+function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
+                         running_mean::CuArray{T}, running_var::CuArray{T},
+                         momentum::T; cache = nothing,
+                         alpha = T(1), beta = T(0),
+                         eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
   dims = _wsize(x)
   if(eps < BATCHNORM_MIN_EPS)
@@ -74,11 +84,13 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
   end
 
   @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
-               (cudnnHandle_t,cudnnBatchNormMode_t,Ptr{Void}, Ptr{Void},
-                Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Void},
-                Ptr{Void},Ptr{Void},Ptr{Void},
-                Cdouble,Ptr{Void},Ptr{Void},
-                Cdouble,Ptr{Void},Ptr{Void}),
+               (cudnnHandle_t,cudnnBatchNormMode_t,
+                Ptr{T}, Ptr{T},
+                Ptr{Void}, Ptr{T},
+                Ptr{Void}, Ptr{T},
+                Ptr{Void}, Ptr{T}, Ptr{T},
+                Cdouble, Ptr{T}, Ptr{T},
+                Cdouble, Ptr{T}, Ptr{T}),
               libcudnn_handle[], BATCHNORM_SPATIAL,
               Ref(T(alpha)), Ref(T(beta)),
               TensorDesc(x), x,
@@ -94,10 +106,12 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
   else
     @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
-                 (cudnnHandle_t,cudnnBatchNormMode_t,Ptr{Void}, Ptr{Void},
-                  Ptr{Void},Ptr{Void},Ptr{Void},Ptr{Void},
-                  Ptr{Void},Ptr{Void},Ptr{Void},
-                  Ptr{Void},Ptr{Void},
+                 (cudnnHandle_t,cudnnBatchNormMode_t,
+                  Ptr{T}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T}, Ptr{T},
+                  Ptr{T}, Ptr{T},
                   Cdouble),
                 libcudnn_handle[], BATCHNORM_SPATIAL,
                 Ref(T(alpha)), Ref(T(beta)),
@@ -107,7 +121,47 @@ function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray
                 running_mean, running_var,
                 eps)
   end
-  y
+end
+
+function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
+                          dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
+                          running_mean::CuArray{T}, running_var::CuArray{T},
+                          momentum; training = true,
+                          cache = nothing, eps = T(1e-5),
+                          alpha = T(1), beta = T(0),
+                          dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
+  if(training)
+
+    if cache !== nothing
+      mean, ivar = cache.mean, cache.ivar
+      info("mean and ivar are fetched from the cache")
+    else
+      mean, ivar = C_NULL, C_NULL
+    end
+
+    @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
+                 (cudnnHandle_t,cudnnBatchNormMode_t,
+                  Ptr{T}, Ptr{T},
+                  Ptr{T}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T},
+                  Ptr{Void}, Ptr{T}, Ptr{T}, Ptr{T},
+                  Cdouble, Ptr{T}, Ptr{T}),
+                 libcudnn_handle[], BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 Ref(T(dalpha)), Ref(T(dbeta)),
+                 TensorDesc(x), x,
+                 TensorDesc(dy), dy,
+                 TensorDesc(dx), dx,
+                 TensorDesc(g), g, dg, db,
+                 eps, mean, ivar)
+  else
+    ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
+    dx .= dy .* reshape(g, _wsize(x)) .* ivar
+    dg .= vec(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)))
+    db .= vec(sum(dy, _reddims(dy)))
+  end
 end
 
 const RNN_RELU = 0 # Stock RNN with ReLu activation
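
A minimal usage sketch of the entry points this patch adds, untested like the patch itself: it assumes a working CUDA device, the CuArrays package from the Flux stack of this era for `cu`, and cuDNN spatial-mode shapes, with the input laid out as (W, H, C, N) and per-channel parameters of length C. The sizes and the momentum value 0.1f0 are illustrative; only bncache, cudnnBNForward and cudnnBNBackward! come from this file.

using CuArrays                        # assumed GPU array package (2018-era Flux)

W, H, C, N = 4, 4, 3, 2               # illustrative spatial batch shape
x = cu(rand(Float32, W, H, C, N))     # input in cuDNN's (W, H, C, N) layout
g = cu(ones(Float32, C))              # per-channel scale γ
b = cu(zeros(Float32, C))             # per-channel bias β
running_mean = cu(zeros(Float32, C))
running_var  = cu(ones(Float32, C))

cache = bncache()                     # will hold saved mean/ivar for the backward pass
y = cudnnBNForward(g, b, x, running_mean, running_var, 0.1f0,
                   cache = cache, training = true)

dy = cu(ones(Float32, W, H, C, N))    # upstream gradient
dg, db, dx = similar(g), similar(b), similar(x)
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, 0.1f0,
                 training = true, cache = cache)

Note the in-place convention: dg, db and dx are preallocated by the caller and overwritten, mirroring cuDNN's own cudnnBatchNormalizationBackward API.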