Fix dimensions
This commit is contained in:
parent
3bc809f49e
commit
e2ae8b4e8d
@ -121,11 +121,11 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
||||
end
|
||||
end
|
||||
|
||||
∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
|
||||
∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T},
|
||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||
beta = T(0), training = true) where T<:Union{Float32, Float64} =
|
||||
∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1), size(dy, 2)),
|
||||
∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), dy,
|
||||
running_mean, running_var, momentum, cache = cache, eps = eps, alpha = alpha, beta = beta,
|
||||
training = training)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user