Integrate cudnn BatchNorm with Flux
commit 3339ad5181 · parent 714ca23aba
@@ -1,5 +1,6 @@
 using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
   cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
+import Flux.data
 
 mutable struct DropoutDesc
   ptr::Ptr{Void}
@@ -27,6 +28,7 @@ const BATCHNORM_ACTIVATION = 0
 const BATCHNORM_MIN_EPS = 1e-5
 
 @inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)
+@inline _reddims(y) = ((i for i=1:ndims(y)-2)..., ndims(y))
 
 mutable struct bncache
   mean
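
For orientation, a minimal sketch (assuming a hypothetical 4-D activation array in Flux's width × height × channels × batch layout) of what the two shape helpers evaluate to:

y = rand(Float32, 28, 28, 16, 8)                  # hypothetical activations
((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)    # _wsize(y) == (1, 1, 16, 1): per-channel parameter shape cuDNN expects
((i for i=1:ndims(y)-2)..., ndims(y))             # _reddims(y) == (1, 2, 4): dims reduced over when accumulating per-channel gradients
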
@@ -35,15 +37,12 @@ end
 
 bncache() = bncache(nothing, nothing)
 
-(BN::BatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} =
-  BN.λ.(cudnnBNForward(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum, cache = cache, eps = BN.ϵ, training = BN.active))
-
-function cudnnBNForward(g, b, x, running_mean::CuArray{T},
-                        running_var::CuArray{T}, momentum;
+function batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
+                   running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
                    cache = nothing, alpha = T(1), beta = T(0),
                    eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
   y = similar(x)
-  cudnnBNForward!(y, data(g), data(b), data(x), running_mean, running_var, momentum, cache = cache,
+  cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
                   alpha = alpha, beta = beta, eps = eps, training = training)
   y
 end
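
A hedged usage sketch of the new functional forward pass (assuming CuArrays is installed and a 16-channel input); the running statistics are per-channel buffers that the call updates in place while training:

using CuArrays

x = cu(rand(Float32, 28, 28, 16, 8))   # activations on the GPU
g = cu(ones(Float32, 16))              # scale γ
b = cu(zeros(Float32, 16))             # shift β
running_mean = cu(zeros(Float32, 16))
running_var  = cu(ones(Float32, 16))

y = batchnorm(g, b, x, running_mean, running_var, 0.1f0; training = true)   # y has the same size as x
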
@@ -111,23 +110,24 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
   end
 end
 
-function cudnnBNBackward(g, b, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T},
-                         running_var::CuArray{T}, momentum;
-                         training = true, cache = nothing, eps = T(1e-5),
-                         alpha = T(1), beta = T(0)) where T<:Union{Float32, Float64}
+function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
+                    running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+                    cache = nothing, eps = T(1e-5), alpha = T(1),
+                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+  dg = similar(g)
+  db = similar(b)
   dx = similar(x)
-  cudnnBNBackward!(g.grad, data(g), b.grad, dx, x, dy, running_mean, running_var, T(momentum),
+  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
                    training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
-  dx
+  (dx, db, dg)
 end
 
 function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
                           dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
                           running_mean::CuArray{T}, running_var::CuArray{T},
-                          momentum; training = true,
-                          cache = nothing, eps = T(1e-5),
+                          momentum; cache = nothing, eps = T(1e-5),
                           alpha = T(1), beta = T(0),
-                          dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
+                          dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
   if(training)
     xd = TensorDesc(x)
     dyd = TensorDesc(dy)
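
Continuing the sketch above, the backward helper now returns the input, shift and scale gradients as a tuple instead of writing into .grad fields:

dy = cu(ones(Float32, size(x)))        # upstream gradient, same shape as x
dx, db, dg = ∇batchnorm(g, b, x, dy, running_mean, running_var, 0.1f0; training = true)
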
@@ -169,3 +169,30 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
     db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
   end
 end
+
+# Flux Interface
+
+import Flux.Tracker: track, back, @back, istracked
+
+_batchnorm(g, b, x, running_mean, running_var, momentum,
+           cache, alpha, beta, eps, training) =
+  batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
+
+batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum;
+          cache = nothing, alpha = T(1), beta = T(0),
+          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
+  track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
+
+batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; cache = nothing, alpha = T(1), beta = T(0),
+          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
+  track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
+
+function back(::typeof(_batchnorm), Δ, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
+  deriv_tup = ∇batchnorm(data(g), data(b), data(x), Δ, running_mean, running_var, momentum,
+                         cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
+  istracked(x) && @back(x, deriv_tup[1])
+  @back(b, deriv_tup[2])
+  @back(g, deriv_tup[3])
+end
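
A sketch of how the Tracker hooks above are meant to be exercised (assuming CuArrays and the tracked batchnorm methods are loaded); param wraps the scale and shift so the custom back rule accumulates into their .grad fields:

using Flux, Flux.Tracker, CuArrays

γ = param(cu(ones(Float32, 16)))       # tracked scale
β = param(cu(zeros(Float32, 16)))      # tracked shift
x = param(cu(rand(Float32, 28, 28, 16, 8)))
running_mean = cu(zeros(Float32, 16))
running_var  = cu(ones(Float32, 16))

y = batchnorm(γ, β, x, running_mean, running_var, 0.1f0)   # dispatches to the TrackedArray method, records _batchnorm
Flux.Tracker.back!(sum(y))                                  # back(::typeof(_batchnorm), ...) routes ∇batchnorm into γ.grad, β.grad, x.grad
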
@@ -104,45 +104,45 @@ mutable struct BatchNorm{F,V,W,N}
   σ::W  # moving std
   ϵ::N
   momentum::N
+  cache
   active::Bool
 end
 
 BatchNorm(chs::Integer, λ = identity;
           initβ = zeros, initγ = ones, ϵ = 1e-5, momentum = .1) =
   BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
-            zeros(chs), ones(chs), ϵ, momentum, true)
+            zeros(chs), ones(chs), ϵ, momentum, nothing, true)
 
-function (BN::BatchNorm)(x)
-  size(x, ndims(x)-1) == length(BN.β) ||
-    error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
-  γ, β = BN.γ, BN.β
+function batchnorm(γ, β, x, μ, σ, momentum; cache = nothing, alpha = 1, beta = 0, eps = 1.0e-5, training = true)
+  size(x, ndims(x)-1) == length(β) ||
+    error("BatchNorm expected $(length(β)) channels, got $(size(x, ndims(x)-1))")
   dims = length(size(x))
   channels = size(x, dims-1)
   affine_shape = ones(Int, dims)
   affine_shape[end-1] = channels
   m = prod(size(x)[1:end-2]) * size(x)[end]
 
-  if !BN.active
-    μ = reshape(BN.μ, affine_shape...)
-    σ = reshape(BN.σ, affine_shape...)
+  if !training
+    μ_curr = reshape(μ, affine_shape...)
+    σ_curr = reshape(σ, affine_shape...)
   else
     T = eltype(x)
 
-    ϵ = data(convert(T, BN.ϵ))
+    eps = Flux.data(convert(T, eps))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
-    μ = mean(x, axes)
-    σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
+    μ_curr = mean(x, axes)
+    σ_curr = sqrt.(mean((x .- μ_curr).^2, axes) .+ eps)
 
     # update moving mean/std
-    mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
+    mtm = Flux.data(convert(T, momentum))
+    μ .= (1 - mtm) .* μ .+ mtm .* squeeze(Flux.data(μ_curr), (axes...))
+    σ .= (1 - mtm) .* σ .+ mtm .* squeeze(Flux.data(σ_curr), (axes...)) .* m ./ (m - 1)
   end
 
-  let λ = BN.λ
-    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
-  end
+  reshape(γ, affine_shape...) .* ((x .- μ_curr) ./ σ_curr) .+ reshape(β, affine_shape...)
 end
+
+(BN::BatchNorm)(x) = BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = BN.cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
 
 children(BN::BatchNorm) =
   (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
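
Finally, a sketch of the layer-level behaviour after the refactor, on the CPU (assuming Flux's existing testmode! helper, which flips the active flag):

using Flux

bn = BatchNorm(16)                 # λ = identity, 16 channels, cache = nothing
x  = rand(Float32, 28, 28, 16, 8)

y_train = bn(x)        # training path: batch statistics, updates bn.μ and bn.σ in place
Flux.testmode!(bn)     # sets bn.active = false
y_test  = bn(x)        # inference path: uses the stored moving statistics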