From d6aacf413584b8d9dcde197972f246e8f7b56c3d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 10 Nov 2018 11:43:49 +0530 Subject: [PATCH] Fix reshape --- src/cuda/cudnn.jl | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl index f71742a8..b14b1851 100644 --- a/src/cuda/cudnn.jl +++ b/src/cuda/cudnn.jl @@ -121,13 +121,15 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray end end -∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T}, +function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2}, running_mean::CuArray{T}, running_var::CuArray{T}, momentum; cache = nothing, eps = T(1e-5), alpha = T(1), - beta = T(0), training = true) where T<:Union{Float32, Float64} = - ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), dy, - running_mean, running_var, momentum, cache = cache, eps = eps, alpha = alpha, beta = beta, - training = training) + beta = T(0), training = true) where T<:Union{Float32, Float64} + dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1), + size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps, + alpha = alpha, beta = beta, training = training) + (dg, db, dropdims(dx, dims = (1, 2))) +end function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T}, running_var::CuArray{T}, momentum;