From d4b066fdf96bd503c7c2bb9bd29bed2b0ab8787a Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 12 Jun 2018 17:49:21 +0530
Subject: [PATCH] Forward Pass for BatchNorm Added

---
 src/cuda/cudnn.jl | 95 +++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 93 insertions(+), 2 deletions(-)

diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index bcadcf4f..d517024e 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -1,5 +1,7 @@
-using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, libcudnn_handle,
-  cudnnDataType, TensorDesc, FilterDesc
+using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
+  cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
+using CuArrays
+using Flux
 
 mutable struct DropoutDesc
   ptr::Ptr{Void}
@@ -22,6 +24,95 @@ function DropoutDesc(ρ::Real; seed::Integer=0)
   return desc
 end
 
+CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
+CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},
+                                <:CuParam{T,1},<:CuArray{T,1},
+                                <:Union{Float32,Float64}}
+
+CuBatchNorm(chs::Integer, λ = identity;
+            initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1) =
+  BatchNorm(λ, param(cu(initβ(Float32, chs))), param(cu(initγ(Float32, chs))),
+            cu(zeros(Float32, chs)), cu(ones(Float32, chs)), ϵ, momentum, true)
+
+const BATCHNORM_SPATIAL = 1
+const BATCHNORM_ACTIVATION = 0
+const BATCHNORM_MIN_EPS = 1e-5
+
+# Shape of the scale/bias/mean/variance tensors: singleton in every
+# dimension except the channel one, e.g. (1, 1, C, 1) for a 4-d input.
+@inline _wsize(y) = ((1 for _ = 1:ndims(y)-2)..., size(y)[end-1], 1)
+
+# Stores the batch mean and inverse variance saved during the training
+# forward pass, for reuse by the backward pass.
+mutable struct bncache
+  mean
+  ivar
+end
+
+bncache() = bncache(nothing, nothing)
+
+(CuBN::CuBatchNorm)(x::CuArray{T}) where T<:Union{Float32, Float64} =
+  CuBN.λ.(cudnnBatchNormalizationForward(Flux.data(CuBN.γ), Flux.data(CuBN.β), x,
+                                         CuBN.μ, CuBN.σ, CuBN.momentum,
+                                         eps = CuBN.ϵ, training = CuBN.active))
+
+function cudnnBatchNormalizationForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
+                                        running_mean::CuArray{T}, running_var::CuArray{T},
+                                        momentum; cache = nothing,
+                                        alpha = T(1), beta = T(0),
+                                        eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+  y = similar(x)
+  dims = _wsize(x)
+
+  if eps < BATCHNORM_MIN_EPS
+    warn("eps ", eps, " is too small for CuDNN; clamping it to ", BATCHNORM_MIN_EPS)
+    eps = BATCHNORM_MIN_EPS
+  end
+
+  if training
+    if cache !== nothing
+      mean = CuArray(zeros(T, dims...))
+      ivar = CuArray(ones(T, dims...))
+    else
+      mean = C_NULL
+      ivar = C_NULL
+    end
+
+    @check ccall((:cudnnBatchNormalizationForwardTraining, libcudnn), cudnnStatus_t,
+                 (cudnnHandle_t, cudnnBatchNormMode_t, Ptr{Void}, Ptr{Void},
+                  Ptr{Void}, Ptr{Void}, Ptr{Void}, Ptr{Void},
+                  Ptr{Void}, Ptr{Void}, Ptr{Void},
+                  Cdouble, Ptr{Void}, Ptr{Void},
+                  Cdouble, Ptr{Void}, Ptr{Void}),
+                 libcudnn_handle[], BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 TensorDesc(x), x,
+                 TensorDesc(y), y,
+                 TensorDesc(T, dims), g, b,
+                 momentum, running_mean, running_var,
+                 eps, mean, ivar)
+
+    if cache !== nothing
+      cache.mean = mean
+      cache.ivar = ivar
+    end
+  else
+    @check ccall((:cudnnBatchNormalizationForwardInference, libcudnn), cudnnStatus_t,
+                 (cudnnHandle_t, cudnnBatchNormMode_t, Ptr{Void}, Ptr{Void},
+                  Ptr{Void}, Ptr{Void}, Ptr{Void}, Ptr{Void},
+                  Ptr{Void}, Ptr{Void}, Ptr{Void},
+                  Ptr{Void}, Ptr{Void},
+                  Cdouble),
+                 libcudnn_handle[], BATCHNORM_SPATIAL,
+                 Ref(T(alpha)), Ref(T(beta)),
+                 TensorDesc(x), x,
+                 TensorDesc(y), y,
+                 TensorDesc(T, dims), g, b,
+                 running_mean, running_var,
+                 eps)
+  end
+  y
+end
+
 const RNN_RELU = 0 # Stock RNN with ReLu activation
 const RNN_TANH = 1 # Stock RNN with tanh activation
 const LSTM = 2 # LSTM with no peephole connections
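
For review purposes, here is a minimal smoke test of the new forward pass. This is a hypothetical sketch, not part of the patch: it assumes a CUDA-capable GPU with a working CuArrays CUDNN build, and that the definitions above (`CuBatchNorm`, `bncache`, `cudnnBatchNormalizationForward`) are in scope.

```julia
# Hypothetical smoke test for the CuDNN BatchNorm forward pass (not part of the patch).
using CuArrays, Flux

bn = CuBatchNorm(3)                  # GPU BatchNorm over 3 channels
x  = cu(rand(Float32, 4, 4, 3, 2))   # WHCN input: 4×4 spatial, 3 channels, batch of 2

y = bn(x)                            # training-mode pass; updates bn.μ / bn.σ
@assert size(y) == size(x)

Flux.testmode!(bn)                   # inference mode (sets active = false)
ŷ = bn(x)                            # now uses the running statistics

# Direct call with a bncache to capture the saved batch mean and inverse
# variance, which a future backward pass could reuse:
cache = bncache()
y2 = cudnnBatchNormalizationForward(Flux.data(bn.γ), Flux.data(bn.β), x,
                                    bn.μ, bn.σ, bn.momentum;
                                    cache = cache, training = true)
```

Note that the functor only applies the CuDNN kernel plus the layer's activation; gradients are not yet wired through `cudnnBatchNormalizationForward`, hence the `Flux.data` unwrapping of `γ` and `β` in the forward call.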