Fix missing parenthesis
This commit is contained in:
parent
f12e367cab
commit
24d13ac326
@ -41,24 +41,24 @@ const BATCHNORM_MIN_EPS = 1e-5
|
||||
@inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)
|
||||
|
||||
mutable struct bncache
|
||||
mean
|
||||
ivar
|
||||
mean
|
||||
ivar
|
||||
end
|
||||
|
||||
bncache() = bncache(nothing, nothing)
|
||||
|
||||
(CuBN::CuBatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} =
|
||||
CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active))
|
||||
CuBN.λ.(cudnnBNForward(CuBN.γ, CuBN.β, x, CuBN.μ, CuBN.σ, CuBN.momentum, cache = cache, eps = CuBN.ϵ, training = CuBN.active))
|
||||
|
||||
function cudnnBNForward(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||
momentum::T; cache = nothing,
|
||||
alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||
y = similar(x)
|
||||
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache
|
||||
alpha = alpha, beta = beta, eps = eps, training = training)
|
||||
y
|
||||
y = similar(x)
|
||||
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
|
||||
alpha = alpha, beta = beta, eps = eps, training = training)
|
||||
y
|
||||
end
|
||||
|
||||
function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
|
||||
@ -125,7 +125,7 @@ end
|
||||
|
||||
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}
|
||||
running_mean::CuArray{T}, running_var::CuArray{T},
|
||||
momentum; training = true,
|
||||
cache = nothing, eps = T(1e-5),
|
||||
alpha = T(1), beta = T(0),
|
||||
|
Loading…
Reference in New Issue
Block a user