From 4d703b31a1ee458cd2599e7207555aedd8a2ba28 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Thu, 8 Nov 2018 19:23:07 +0530
Subject: [PATCH] Reshape 2D tensors to use cudnn batchnorm

---
 src/cuda/cudnn.jl | 20 +++++++++++++++++---
 1 file changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 04d937d8..94424421 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -39,8 +39,14 @@ end
 
 BNCache() = BNCache(nothing, nothing)
 
-# CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
-# so use the native julia code when doing batchnorm on a 2D Array
+# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
+# so reshape a 2D Tensor into 4D
+batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
+          running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+          cache = nothing, alpha = T(1), beta = T(0),
+          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
+  batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
+            cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
 
 function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
                    running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
@@ -115,6 +121,14 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
   end
 end
 
+∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
+           running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+           cache = nothing, eps = T(1e-5), alpha = T(1),
+           beta = T(0), training = true) where T<:Union{Float32, Float64} =
+  ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1), size(dy, 2)),
+             running_mean, running_var, momentum, cache = cache, eps = eps, alpha = alpha, beta = beta,
+             training = training)
+
 function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
                     running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
                     cache = nothing, eps = T(1e-5), alpha = T(1),
@@ -176,7 +190,7 @@ end
 
 # Flux Interface
 
-(BN::Flux.BatchNorm)(x::Union{CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
+(BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}}, cache = nothing) where T<:Union{Float32, Float64} =
   batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum; cache = cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
 
 batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
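
Why the reshape is safe: per-feature batchnorm on a (features × batch) matrix
computes the same statistics as CuDNN's per-channel spatial batchnorm on a
(1 × 1 × features × batch) tensor, because the two singleton spatial
dimensions contribute nothing to the per-channel mean and variance. Below is
a minimal CPU sketch of that equivalence in plain Julia (no GPU needed); the
helper names naive_batchnorm2d and naive_batchnorm4d are hypothetical, chosen
only for this illustration, and are not part of Flux or CuDNN:

    using Statistics

    # 2D case: normalize each feature (row) over the batch (columns).
    naive_batchnorm2d(x; eps = 1f-5) =
        (x .- mean(x, dims = 2)) ./ sqrt.(var(x, dims = 2, corrected = false) .+ eps)

    # 4D case: normalize each channel over the spatial and batch dimensions,
    # which is what CuDNN's spatial batchnorm computes per channel.
    naive_batchnorm4d(x; eps = 1f-5) =
        (x .- mean(x, dims = (1, 2, 4))) ./ sqrt.(var(x, dims = (1, 2, 4), corrected = false) .+ eps)

    x  = rand(Float32, 8, 32)                      # 8 features, batch of 32
    x4 = reshape(x, 1, 1, size(x, 1), size(x, 2))  # the patch's 1×1×F×B trick

    # The two paths agree up to floating-point error.
    @assert naive_batchnorm2d(x) ≈ reshape(naive_batchnorm4d(x4), size(x))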
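
With the new CuParam{T,2} method on the Flux interface, a BatchNorm layer
applied to 2D GPU activations (e.g. the output of a Dense layer) should now
take the CuDNN path instead of the native Julia fallback. A rough usage
sketch, assuming a CUDA-capable GPU and the Flux/CuArrays versions this
patch targets:

    using Flux, CuArrays

    bn = BatchNorm(8) |> gpu        # move the layer's parameters to the GPU
    x  = cu(rand(Float32, 8, 32))   # 2D input: 8 features × 32 samples
    y  = bn(x)                      # dispatches to CuDNN via the 1×1×8×32 reshape

Note that, as written, the 2D wrapper returns the result of the reshaped
call, i.e. a 1 × 1 × 8 × 32 array rather than the original 8 × 32 shape, so
callers may need to reshape the output back.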