This commit is contained in:
Avik Pal 2018-06-28 14:45:35 +05:30
parent d0b79e71e2
commit bcf094451c

View File

@ -123,7 +123,7 @@ function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T
dx = similar(x)
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
(dx, db, dx)
(dx, db, dg)
end
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
@ -184,7 +184,8 @@ import ..Tracker: track, back, @back, istracked, TrackedArray
_batchnorm(g, b, x, running_mean, running_var, momentum,
cache, alpha, beta, eps, training) =
batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache,
alpha = alpha, beta = beta, eps = eps, training = training)
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum;