Fix missing parenthesis

This commit is contained in:
Avik Pal 2018-06-12 21:32:56 +05:30
parent f12e367cab
commit 24d13ac326

View File

@ -41,24 +41,24 @@ const BATCHNORM_MIN_EPS = 1e-5
@inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1) @inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)
mutable struct bncache mutable struct bncache
mean mean
ivar ivar
end end
bncache() = bncache(nothing, nothing) bncache() = bncache(nothing, nothing)
(CuBN::CuBatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} = (CuBN::CuBatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} =
CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active)) CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active))
function cudnnBNForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, function cudnnBNForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T}, running_mean::CuArray{T}, running_var::CuArray{T},
momentum::T; cache = nothing, momentum::T; cache = nothing,
alpha = T(1), beta = T(0), alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
y = similar(x) y = similar(x)
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
alpha = alpha, beta = beta, eps = eps, training = training) alpha = alpha, beta = beta, eps = eps, training = training)
y y
end end
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
@ -125,20 +125,20 @@ end
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T} running_mean::CuArray{T}, running_var::CuArray{T},
momentum; training = true, momentum; training = true,
cache = nothing, eps = T(1e-5), cache = nothing, eps = T(1e-5),
alpha = T(1), beta = T(0), alpha = T(1), beta = T(0),
dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64} dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
if(training) if(training)
if cache !== nothing if cache !== nothing
mean, ivar = cache.mean, cache.ivar mean, ivar = cache.mean, cache.ivar
cache_verbose && info("mean and ivar are fetched from the cache") cache_verbose && info("mean and ivar are fetched from the cache")
else else
mean, ivar = C_NULL, C_NULL mean, ivar = C_NULL, C_NULL
end end
@check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t, @check ccall((:cudnnBatchNormalizationBackward, libcudnn), cudnnStatus_t,
(cudnnHandle_t,cudnnBatchNormMode_t, (cudnnHandle_t,cudnnBatchNormMode_t,
Ptr{T}, Ptr{T}, Ptr{T}, Ptr{T},