diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 9948ef37..dd1775ad 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -27,15 +27,16 @@ const BATCHNORM_SPATIAL = 1
 const BATCHNORM_ACTIVATION = 0
 const BATCHNORM_MIN_EPS = 1e-5
 
-@inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)
-@inline _reddims(y) = ((i for i=1:ndims(y)-2)..., ndims(y))
+@inline _wsize(y) = (map(_ -> 1, size(y)[1:end-2])..., size(y)[end-1], 1)
 
-mutable struct bncache
+@inline _reddims(y) = (collect(1:ndims(y)-2)..., ndims(y))
+
+mutable struct BNCache
   mean
   ivar
 end
 
-bncache() = bncache(nothing, nothing)
+BNCache() = BNCache(nothing, nothing)
 
 # CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
 # so use the native julia code when doing batchnorm on a 2D Array
@@ -56,7 +57,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
                          alpha = T(1), beta = T(0),
                          eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
   dims = _wsize(x)
-  if(eps < BATCHNORM_MIN_EPS)
+  if eps < BATCHNORM_MIN_EPS
     warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
     eps = BATCHNORM_MIN_EPS
   end
@@ -64,11 +65,11 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
   yd = TensorDesc(y)
   gd = TensorDesc(T, (1,1,length(g),1))
 
-  if(training)
+  if training
 
-    if(cache !== nothing)
-      mean = cu(zeros(T, dims...))
-      ivar = cu(ones(T, dims...))
+    if cache !== nothing
+      mean = zeros(CuArray{T}, dims...)
+      ivar = ones(CuArray{T}, dims...)
     else
       mean = C_NULL
       ivar = C_NULL
@@ -90,7 +91,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
               momentum, running_mean, running_var,
               eps, mean, ivar)
 
-    if(cache !== nothing)
+    if cache !== nothing
       cache.mean = mean
       cache.ivar = ivar
     end
@@ -131,7 +132,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
                           momentum; cache = nothing, eps = T(1e-5),
                           alpha = T(1), beta = T(0),
                           dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
-  if(training)
+  if training
     xd = TensorDesc(x)
     dyd = TensorDesc(dy)
     dxd = TensorDesc(dx)
@@ -143,7 +144,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
     mean, ivar = C_NULL, C_NULL
   end
 
-  if(eps < BATCHNORM_MIN_EPS)
+  if eps < BATCHNORM_MIN_EPS
     warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
     eps = BATCHNORM_MIN_EPS
   end
@@ -175,7 +176,8 @@ end
 
 # Flux Interface
 
-import Flux.Tracker: track, back, @back, istracked
+import ..Flux: Flux
+import ..Tracker: track, back, @back, istracked, TrackedArray
 
 _batchnorm(g, b, x, running_mean, running_var, momentum,
            cache, alpha, beta, eps, training) =
@@ -195,7 +197,7 @@ batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray
 function back(::typeof(_batchnorm), Δ, g, b, x, running_mean, running_var, momentum,
               cache, alpha, beta, eps, training)
   deriv_tup = ∇batchnorm(data(g), data(b), data(x), Δ, running_mean, running_var, momentum,
                          cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
-  istracked(x) && @back(x, deriv_tup[1])
+  @back(x, deriv_tup[1])
  @back(b, deriv_tup[2])
  @back(g, deriv_tup[3])
 end
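
Usage sketch (not part of the patch): a minimal example of how the renamed interface might be exercised after this change. The `Flux.CUDA` module path is an assumption based on the file living under src/cuda/, and the call shape of `batchnorm` is inferred from the `TrackedArray` method and the `_batchnorm` arguments visible in the diff; shapes and hyperparameters are illustrative only.

using Flux, CuArrays

x  = cu(rand(Float32, 8, 8, 3, 16))      # 4-D (W, H, C, N) input, so the CuDNN path applies;
                                         # 2-D inputs fall back to the native Julia code
g  = Flux.param(cu(ones(Float32, 3)))    # tracked per-channel scale, matching the TrackedArray method
b  = Flux.param(cu(zeros(Float32, 3)))   # tracked per-channel shift
μ  = cu(zeros(Float32, 3))               # running mean
σ² = cu(ones(Float32, 3))                # running variance

cache = Flux.CUDA.BNCache()              # renamed from `bncache`; holds mean/ivar for the backward pass
y = Flux.CUDA.batchnorm(g, b, x, μ, σ², 0.1f0;
                        cache = cache, eps = 1f-5, training = true)

During training, `cudnnBNForward!` now allocates `mean`/`ivar` via `zeros(CuArray{T}, dims...)` rather than `cu(zeros(T, dims...))`, and stores them in the cache so `∇batchnorm` can reuse them in `back`.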