diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index abcd6737..6e2c9e75 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -58,7 +58,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
                          eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
   dims = _wsize(x)
   if eps < BATCHNORM_MIN_EPS
-    warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
+    # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
     eps = BATCHNORM_MIN_EPS
   end
   xd = TensorDesc(x)
@@ -145,7 +145,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
   end
 
   if eps < BATCHNORM_MIN_EPS
-    warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
+    # warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
     eps = BATCHNORM_MIN_EPS
   end
 
@@ -187,5 +187,25 @@ batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray
           running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
   track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
 
+batchnorm(g::TrackedArray, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::CuArray{T}, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::CuArray{T}, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::TrackedArray, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
+batchnorm(g::CuArray{T}, b::CuArray{T}, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; kw...) where T<:Union{Float32, Float64} =
+  track(batchnorm, g, b, x, running_mean, running_var, momentum; kw...)
+
 @grad batchnorm(g, b, x, running_mean, running_var, momentum; kw...) =
   batchnorm(data.((g, b, x))..., running_mean, running_var, momentum; kw...), Δ -> (nobacksies(:batchnorm, ∇batchnorm(data.((g, b, x, Δ))..., running_mean, running_var, momentum; kw...))..., nothing, nothing, nothing)
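
For context, a minimal usage sketch of what the added method combinations enable. It is not part of the diff: the module path `Flux.CUDA.batchnorm` and the Tracker helpers `param`, `back!`, and `grad` are assumptions based on the Tracker-era Flux layout. The point is only that a call mixing tracked and plain `CuArray` arguments now dispatches through `track`, so gradients reach the tracked argument instead of falling through to the untracked `CuArray` method.

# Sketch only; names outside the diff (Flux.CUDA.batchnorm, param, back!, grad) are assumed.
using Flux, CuArrays
using Flux.Tracker: param, back!, grad

x  = cu(rand(Float32, 4, 4, 3, 2))   # plain CuArray input, channels along dim 3
g  = param(cu(ones(Float32, 3)))     # tracked scale (γ) — the only tracked argument
b  = cu(zeros(Float32, 3))           # untracked shift (β)
μ  = cu(zeros(Float32, 3))           # running mean
σ² = cu(ones(Float32, 3))            # running variance

# With the new (TrackedArray, CuArray, CuArray, ...) method this call is recorded
# by Tracker; previously only the all-tracked / g,b-tracked combinations were.
y = Flux.CUDA.batchnorm(g, b, x, μ, σ², 0.1f0; training = true)

back!(sum(y))   # propagate a scalar loss back through the CuDNN batchnorm
grad(g)         # gradient with respect to the tracked scale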