From 24d13ac3262e6488227a63e93f971800d2fc756e Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 12 Jun 2018 21:32:56 +0530 Subject: [PATCH] Fix missing parenthesis --- src/cuda/cudnn.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index fd0dd7a6..2c2be1d6 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -41,24 +41,24 @@ const BATCHNORM_MIN_EPS = 1e-5 @inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1) mutable struct bncache - mean - ivar + mean + ivar end bncache() = bncache(nothing, nothing) (CuBN::CuBatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} = - CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active)) + CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active)) function cudnnBNForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, running_mean::CuArray{T}, running_var::CuArray{T}, momentum::T; cache = nothing, alpha = T(1), beta = T(0), eps = T(1e-5), training = true) where T<:Union{Float32, Float64} - y = similar(x) - cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache - alpha = alpha, beta = beta, eps = eps, training = training) - y + y = similar(x) + cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache, + alpha = alpha, beta = beta, eps = eps, training = training) + y end function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, @@ -125,20 +125,20 @@ end function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, - running_mean::CuArray{T}, running_var::CuArray{T} - momentum; training = true, + running_mean::CuArray{T}, running_var::CuArray{T}, + momentum; training = true, cache = nothing, eps = T(1e-5), alpha = T(1), beta = T(0), dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64} if(training) - + if cache !== nothing mean, ivar = cache.mean, cache.ivar cache_verbose && info("mean and ivar are fetched from the cache") else mean, ivar = C_NULL, C_NULL end - + @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t, (cudnnHandle_t,cudnnBatchNormMode_t, Ptr{T}, Ptr{T},