Make changes as per the review
This commit is contained in:
parent
f29377123e
commit
24ba1c4e6c
@ -27,15 +27,16 @@ const BATCHNORM_SPATIAL = 1
|
||||
# NOTE(review): presumably the cuDNN per-activation batchnorm mode value — confirm against the cuDNN headers.
const BATCHNORM_ACTIVATION = 0
|
||||
# Lower bound for `eps`: cudnnBNForward! and cudnnBNBackward! emit a warning and
# replace any smaller `eps` with this value.
const BATCHNORM_MIN_EPS = 1e-5
|
||||
|
||||
@inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)
|
||||
@inline _reddims(y) = ((i for i=1:ndims(y)-2)..., ndims(y))
|
||||
# Return a size tuple with as many entries as `y` has dimensions, where every
# entry is 1 except the second-to-last, which is `size(y)[end-1]`.
# cudnnBNForward! uses this shape when allocating its `mean`/`ivar` buffers.
@inline function _wsize(y)
    dims = size(y)
    leading = map(_ -> 1, dims[1:end-2])
    return (leading..., dims[end-1], 1)
end
|
||||
|
||||
mutable struct bncache
|
||||
# Return the tuple of dimension indices `(1, 2, ..., ndims(y)-2, ndims(y))`,
# i.e. every dimension of `y` except the second-to-last.
# NOTE(review): the name suggests these are the dimensions reduced over when
# computing batchnorm statistics — confirm against the callers.
# The range is splatted directly; materialising it with `collect` first only
# allocated a temporary Vector and produced the same tuple.
@inline function _reddims(y)
    n = ndims(y)
    return ((1:n-2)..., n)
end
|
||||
|
||||
# Mutable holder for the `mean` and `ivar` values of a batchnorm forward pass.
# cudnnBNForward! assigns both fields when it is given a cache that is not `nothing`;
# the zero-argument constructor initialises both to `nothing`.
mutable struct BNCache
|
||||
# Untyped field: `nothing` until cudnnBNForward! stores its `mean` buffer here.
mean
|
||||
# Untyped field: `nothing` until cudnnBNForward! stores its `ivar` buffer here.
# NOTE(review): presumably the inverse variance saved for the backward pass — confirm.
ivar
|
||||
end
|
||||
|
||||
bncache() = bncache(nothing, nothing)
|
||||
# Build an empty cache: both `mean` and `ivar` start out as `nothing`.
function BNCache()
    return BNCache(nothing, nothing)
end
|
||||
|
||||
# CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
||||
# so use the native Julia code when doing batchnorm on a 2D Array
|
||||
@ -56,7 +57,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
||||
alpha = T(1), beta = T(0),
|
||||
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
|
||||
dims = _wsize(x)
|
||||
if(eps < BATCHNORM_MIN_EPS)
|
||||
if eps < BATCHNORM_MIN_EPS
|
||||
warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||
eps = BATCHNORM_MIN_EPS
|
||||
end
|
||||
@ -64,11 +65,11 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
||||
yd = TensorDesc(y)
|
||||
gd = TensorDesc(T, (1,1,length(g),1))
|
||||
|
||||
if(training)
|
||||
if training
|
||||
|
||||
if(cache !== nothing)
|
||||
mean = cu(zeros(T, dims...))
|
||||
ivar = cu(ones(T, dims...))
|
||||
if cache !== nothing
|
||||
mean = zeros(CuArray{T}, dims...)
|
||||
ivar = ones(CuArray{T}, dims...)
|
||||
else
|
||||
mean = C_NULL
|
||||
ivar = C_NULL
|
||||
@ -90,7 +91,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
||||
momentum, running_mean, running_var,
|
||||
eps, mean, ivar)
|
||||
|
||||
if(cache !== nothing)
|
||||
if cache !== nothing
|
||||
cache.mean = mean
|
||||
cache.ivar = ivar
|
||||
end
|
||||
@ -131,7 +132,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||
momentum; cache = nothing, eps = T(1e-5),
|
||||
alpha = T(1), beta = T(0),
|
||||
dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
|
||||
if(training)
|
||||
if training
|
||||
xd = TensorDesc(x)
|
||||
dyd = TensorDesc(dy)
|
||||
dxd = TensorDesc(dx)
|
||||
@ -143,7 +144,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
|
||||
mean, ivar = C_NULL, C_NULL
|
||||
end
|
||||
|
||||
if(eps < BATCHNORM_MIN_EPS)
|
||||
if eps < BATCHNORM_MIN_EPS
|
||||
warn("eps ",eps," is too small for CuDNN so eps has been assigned the value ", BATCHNORM_MIN_EPS)
|
||||
eps = BATCHNORM_MIN_EPS
|
||||
end
|
||||
@ -175,7 +176,8 @@ end
|
||||
|
||||
# Flux Interface
|
||||
|
||||
import Flux.Tracker: track, back, @back, istracked
|
||||
import ..Flux: Flux
|
||||
import ..Tracker: track, back, @back, istracked, TrackedArray
|
||||
|
||||
_batchnorm(g, b, x, running_mean, running_var, momentum,
|
||||
cache, alpha, beta, eps, training) =
|
||||
@ -195,7 +197,7 @@ batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray
|
||||
# Backward rule for `_batchnorm`: calls ∇batchnorm on the unwrapped (`data`) g, b and x
# together with the incoming gradient Δ, then forwards elements 1, 2 and 3 of the
# returned tuple to x, b and g respectively via `@back`.
function back(::typeof(_batchnorm), Δ, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
|
||||
deriv_tup = ∇batchnorm(data(g), data(b), data(x), Δ, running_mean, running_var, momentum,
|
||||
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
|
||||
# NOTE(review): this guarded call and the unguarded `@back(x, ...)` that follows look like
# the before/after sides of one changed diff line — confirm only one of them exists in the
# real file, otherwise a tracked x would receive its gradient twice.
istracked(x) && @back(x, deriv_tup[1])
|
||||
@back(x, deriv_tup[1])
|
||||
@back(b, deriv_tup[2])
|
||||
|
||||
@back(g, deriv_tup[3])
|
||||
end
|
||||
|
Loading…
Reference in New Issue
Block a user