diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index 5d661889..f71742a8 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -121,11 +121,11 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray end end -∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2}, +∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T}, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; cache = nothing, eps = T(1e-5), alpha = T(1), beta = T(0), training = true) where T<:Union{Float32, Float64} = - ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1), size(dy, 2)), + ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), dy, running_mean, running_var, momentum, cache = cache, eps = eps, alpha = alpha, beta = beta, training = training)