diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index d5c2de09..6d15fa61 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -123,7 +123,7 @@ function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T dx = similar(x) cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum), training = training, cache = cache, eps = eps, alpha = alpha, beta = beta) - (dx, db, dx) + (dx, db, dg) end function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T}, @@ -184,7 +184,8 @@ import ..Tracker: track, back, @back, istracked, TrackedArray _batchnorm(g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training) = - batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, alpha = alpha, beta = beta, eps = eps, training = training) + batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, + alpha = alpha, beta = beta, eps = eps, training = training) batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T}, running_var::CuArray{T}, momentum;