Reshape 2D tensors to use cudnn batchnorm
This commit is contained in:
parent
564518e448
commit
4d703b31a1
@ -39,8 +39,14 @@ end
|
|||||||
|
|
||||||
# Convenience constructor: an empty cache with both fields unset.
# NOTE(review): the BNCache struct definition is outside this view — presumably
# the two fields hold saved mean/ivar from the forward pass; confirm at the struct.
BNCache() = BNCache(nothing, nothing)
|
BNCache() = BNCache(nothing, nothing)
|
||||||
|
|
||||||
# CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
# NOTE: CuDNN supports only 4D and 5D Tensors for BatchNorm Operations
|
||||||
# so use the native julia code when doing batchnorm on a 2D Array
|
# so reshape a 2D Tensor into 4D
|
||||||
|
batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
|
cache = nothing, alpha = T(1), beta = T(0),
|
||||||
|
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
|
||||||
|
batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), running_mean, running_var, momentum,
|
||||||
|
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
|
||||||
|
|
||||||
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
|
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::Union{CuArray{T, 4},CuArray{T,5}},
|
||||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
@ -115,6 +121,14 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T, 2}, dy::CuArray{T, 2},
|
||||||
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
|
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||||
|
beta = T(0), training = true) where T<:Union{Float32, Float64} =
|
||||||
|
∇batchnorm(g, b, reshape(x, 1, 1, size(x, 1), size(x, 2)), reshape(dy, 1, 1, size(dy, 1), size(dy, 2)),
|
||||||
|
running_mean, running_var, momentum, cache = cache, eps = eps, alpha = alpha, beta = beta,
|
||||||
|
training = training)
|
||||||
|
|
||||||
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
|
||||||
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
|
||||||
cache = nothing, eps = T(1e-5), alpha = T(1),
|
cache = nothing, eps = T(1e-5), alpha = T(1),
|
||||||
@ -176,7 +190,7 @@ end
|
|||||||
|
|
||||||
# Flux Interface
|
# Flux Interface
|
||||||
|
|
||||||
# Flux interface: route BatchNorm layers applied to CUDA arrays (2D, 4D or 5D,
# Float32/Float64) through the CuDNN-backed `batchnorm` instead of the generic
# Julia implementation. `cache` is forwarded so a backward pass can reuse the
# statistics saved during the forward pass.
function (BN::Flux.BatchNorm)(x::Union{CuParam{T,2},CuParam{T,4},CuParam{T,5}},
                              cache = nothing) where T<:Union{Float32, Float64}
  return batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ², BN.momentum;
                   cache = cache, alpha = 1, beta = 0,
                   eps = BN.ϵ, training = BN.active)
end
|
||||||
|
|
||||||
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
|
||||||
|
Loading…
Reference in New Issue
Block a user