Fix reshape

This commit is contained in:
Avik Pal 2018-11-10 11:43:49 +05:30
parent e2ae8b4e8d
commit d6aacf4135

View File

@ -121,13 +121,15 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
end end
end end
∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T}, function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum; running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, eps = T(1e-5), alpha = T(1), cache = nothing, eps = T(1e-5), alpha = T(1),
beta = T(0), training = true) where T<:Union{Float32, Float64} = beta = T(0), training = true) where T<:Union{Float32, Float64}
∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), dy, dg, db, dx = ∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1),
running_mean, running_var, momentum, cache = cache, eps = eps, alpha = alpha, beta = beta, size(dy, 2)), running_mean, running_var, momentum, cache = cache, eps = eps,
training = training) alpha = alpha, beta = beta, training = training)
(dg, db, dropdims(dx, dims = (1, 2)))
end
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T}, function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum; running_mean::CuArray{T}, running_var::CuArray{T}, momentum;