Integrate cudnn BatchNorm with Flux

Avik Pal 2018-06-20 15:50:30 +05:30
parent 714ca23aba
commit 3339ad5181
2 changed files with 61 additions and 34 deletions
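
In effect, a BatchNorm layer whose parameters and input have been moved to the GPU now reaches the cuDNN-backed `batchnorm` below through ordinary dispatch. A minimal usage sketch, assuming the Flux/CuArrays API of this period (`gpu`, `param`, WHCN layout):

using Flux, CuArrays

# Hypothetical sketch: with the layer and data on the GPU, the `batchnorm`
# call inside (BN::BatchNorm)(x) resolves to the cuDNN methods added in this
# commit rather than the generic Julia implementation.
m = BatchNorm(16) |> gpu                # 16-channel layer, parameters as CuArrays
x = gpu(rand(Float32, 28, 28, 16, 4))   # WHCN batch of 4
y = m(x)                                # forward pass through cudnnBNForward!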


@@ -1,5 +1,6 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
import Flux.data
mutable struct DropoutDesc
ptr::Ptr{Void}
@@ -27,6 +28,7 @@ const BATCHNORM_ACTIVATION = 0
const BATCHNORM_MIN_EPS = 1e-5
@inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)
@inline _reddims(y) = ((i for i=1:ndims(y)-2)..., ndims(y))
mutable struct bncache
mean
@@ -35,15 +37,12 @@ end
bncache() = bncache(nothing, nothing)
(BN::BatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} =
BN.λ.(cudnnBNForward(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum, cache = cache, eps = BN.ϵ, training = BN.active))
function cudnnBNForward(g, b, x, running_mean::CuArray{T},
running_var::CuArray{T}, momentum;
function batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
y = similar(x)
cudnnBNForward!(y, data(g), data(b), data(x), running_mean, running_var, momentum, cache = cache,
cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
alpha = alpha, beta = beta, eps = eps, training = training)
y
end
@@ -111,23 +110,24 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
end
end
function cudnnBNBackward(g, b, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum;
training = true, cache = nothing, eps = T(1e-5),
alpha = T(1), beta = T(0)) where T<:Union{Float32, Float64}
function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
cache = nothing, eps = T(1e-5), alpha = T(1),
beta = T(0), training = true) where T<:Union{Float32, Float64}
dg = similar(g)
db = similar(b)
dx = similar(x)
cudnnBNBackward!(g.grad, data(g), b.grad, dx, x, dy, running_mean, running_var, T(momentum),
cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
dx
(dx, db, dg)
end
function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
running_mean::CuArray{T}, running_var::CuArray{T},
momentum; training = true,
cache = nothing, eps = T(1e-5),
momentum; cache = nothing, eps = T(1e-5),
alpha = T(1), beta = T(0),
dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
if training
xd = TensorDesc(x)
dyd = TensorDesc(dy)
@@ -169,3 +169,30 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
end
end
# Flux Interface
import Flux.Tracker: track, back, @back, istracked
_batchnorm(g, b, x, running_mean, running_var, momentum,
cache, alpha, beta, eps, training) =
batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
running_var::CuArray{T}, momentum;
cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
running_var::CuArray{T}, momentum; cache = nothing, alpha = T(1), beta = T(0),
eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
function back(::typeof(_batchnorm), Δ, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
deriv_tup = ∇batchnorm(data(g), data(b), data(x), Δ, running_mean, running_var, momentum,
cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
istracked(x) && @back(x, deriv_tup[1])
@back(b, deriv_tup[2])
@back(g, deriv_tup[3])
end
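
The tracked `batchnorm` methods above record the call with `track`, and `back` routes the incoming gradient through `∇batchnorm` into `γ.grad`, `β.grad`, and (when `x` is tracked) `x`. A rough usage sketch, assuming these methods are in scope; they live in Flux's CUDA glue and may need qualification:

using Flux, CuArrays

# Hypothetical sketch of the tracked GPU path (names as defined above).
γ = param(cu(ones(Float32, 4)))
β = param(cu(zeros(Float32, 4)))
x = cu(rand(Float32, 8, 8, 4, 2))
μ, σ = cu(zeros(Float32, 4)), cu(ones(Float32, 4))

y = batchnorm(γ, β, x, μ, σ, 0.1f0; training = true)
Flux.back!(sum(y))   # cuDNN backward pass populates γ.grad and β.grad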


@@ -104,45 +104,45 @@ mutable struct BatchNorm{F,V,W,N}
σ::W # moving std
ϵ::N
momentum::N
cache
active::Bool
end
BatchNorm(chs::Integer, λ = identity;
initβ = zeros, initγ = ones, ϵ = 1e-5, momentum = .1) =
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, true)
zeros(chs), ones(chs), ϵ, momentum, nothing, true)
function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
function batchnorm(γ, β, x, μ, σ, momentum; cache = nothing, alpha = 1, beta = 0, eps = 1.0e-5, training = true)
size(x, ndims(x)-1) == length(β) ||
error("BatchNorm expected $(length(β)) channels, got $(size(x, ndims(x)-1))")
γ, β = BN.γ, BN.β
dims = length(size(x))
channels = size(x, dims-1)
affine_shape = ones(Int, dims)
affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end]
if !BN.active
μ = reshape(BN.μ, affine_shape...)
σ = reshape(BN.σ, affine_shape...)
if !training
μ_curr = reshape(μ, affine_shape...)
σ_curr = reshape(σ, affine_shape...)
else
T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
eps = Flux.data(convert(T, eps))
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, axes)
σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
μ_curr = mean(x, axes)
σ_curr = sqrt.(mean((x .- μ_curr).^2, axes) .+ eps)
# update moving mean/std
mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
mtm = Flux.data(convert(T, momentum))
μ .= (1 - mtm) .* μ .+ mtm .* squeeze(Flux.data(μ_curr), (axes...))
σ .= (1 - mtm) .* σ .+ mtm .* squeeze(Flux.data(σ_curr), (axes...)) .* m ./ (m - 1)
end
reshape(γ, affine_shape...) .* ((x .- μ_curr) ./ σ_curr) .+ reshape(β, affine_shape...)
end
let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
end
end
(BN::BatchNorm)(x) = BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = BN.cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
children(BN::BatchNorm) =
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
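
For comparison, the refactored CPU path can be called in the same way; a minimal sketch, assuming `batchnorm` stays unexported and is reached as `Flux.batchnorm`:

using Flux

γ, β = param(ones(3)), param(zeros(3))
x = rand(5, 5, 3, 2)          # WHCN input with 3 channels
μ, σ = zeros(3), ones(3)      # running statistics, updated in place

# Training-mode call: normalizes over every dimension except the channel
# dimension and writes the updated moving mean/std back into μ and σ.
y = Flux.batchnorm(γ, β, x, μ, σ, 0.1; training = true)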