diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index 132e105f..c7d997b9 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -1,5 +1,6 @@
 using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
   cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
+import Flux.data
 
 mutable struct DropoutDesc
   ptr::Ptr{Void}
@@ -27,6 +28,7 @@ const BATCHNORM_ACTIVATION = 0
 const BATCHNORM_MIN_EPS = 1e-5
 
 @inline _wsize(y) = ((1 for _=1:ndims(y)-2)..., size(y)[end-1], 1)
+@inline _reddims(y) = ((i for i=1:ndims(y)-2)..., ndims(y))
 
 mutable struct bncache
   mean
@@ -35,15 +37,12 @@ end
 
 bncache() = bncache(nothing, nothing)
 
-(BN::BatchNorm)(x::CuArray{T}; cache = nothing) where T<:Union{Float32, Float64} =
-  BN.λ.(cudnnBNForward(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum, cache = cache, eps = BN.ϵ, training = BN.active))
-
-function cudnnBNForward(g, b, x, running_mean::CuArray{T},
-                        running_var::CuArray{T}, momentum;
-                        cache = nothing, alpha = T(1), beta = T(0),
-                        eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
+function batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T},
+                   running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+                   cache = nothing, alpha = T(1), beta = T(0),
+                   eps = T(1e-5), training = true) where T<:Union{Float32, Float64}
   y = similar(x)
-  cudnnBNForward!(y, data(g), data(b), data(x), running_mean, running_var, momentum, cache = cache,
+  cudnnBNForward!(y, g, b, x, running_mean, running_var, momentum, cache = cache,
                   alpha = alpha, beta = beta, eps = eps, training = training)
   y
 end
@@ -111,23 +110,24 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
   end
 end
 
-function cudnnBNBackward(g, b, x::CuArray{T}, dy::CuArray{T}, running_mean::CuArray{T},
-                         running_var::CuArray{T}, momentum;
-                         training = true, cache = nothing, eps = T(1e-5),
-                         alpha = T(1), beta = T(0)) where T<:Union{Float32, Float64}
+function ∇batchnorm(g::CuArray{T}, b::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
+                    running_mean::CuArray{T}, running_var::CuArray{T}, momentum;
+                    cache = nothing, eps = T(1e-5), alpha = T(1),
+                    beta = T(0), training = true) where T<:Union{Float32, Float64}
+  dg = similar(g)
+  db = similar(b)
   dx = similar(x)
-  cudnnBNBackward!(g.grad, data(g), b.grad, dx, x, dy, running_mean, running_var, T(momentum),
+  cudnnBNBackward!(dg, g, db, dx, x, dy, running_mean, running_var, T(momentum),
                    training = training, cache = cache, eps = eps, alpha = alpha, beta = beta)
-  dx
+  (dx, db, dg)
 end
 
 function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
                           dx::CuArray{T}, x::CuArray{T}, dy::CuArray{T},
                           running_mean::CuArray{T}, running_var::CuArray{T},
-                          momentum; training = true,
-                          cache = nothing, eps = T(1e-5),
+                          momentum; cache = nothing, eps = T(1e-5),
                           alpha = T(1), beta = T(0),
-                          dalpha = T(1), dbeta = T(0)) where T<:Union{Float32, Float64}
+                          dalpha = T(1), dbeta = T(0), training = true) where T<:Union{Float32, Float64}
   if(training)
     xd = TensorDesc(x)
     dyd = TensorDesc(dy)
@@ -169,3 +169,30 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
     db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
   end
 end
+
+# Flux Interface
+
+import Flux.Tracker: track, back, @back, istracked
+
+_batchnorm(g, b, x, running_mean, running_var, momentum,
+           cache, alpha, beta, eps, training) =
+  batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
+
+batchnorm(g::TrackedArray, b::TrackedArray, x::TrackedArray, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum;
+          cache = nothing, alpha = T(1), beta = T(0),
+          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
+  track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
+
+batchnorm(g::TrackedArray, b::TrackedArray, x::CuArray{T}, running_mean::CuArray{T},
+          running_var::CuArray{T}, momentum; cache = nothing, alpha = T(1), beta = T(0),
+          eps = T(1e-5), training = true) where T<:Union{Float32, Float64} =
+  track(_batchnorm, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
+
+function back(::typeof(_batchnorm), Δ, g, b, x, running_mean, running_var, momentum, cache, alpha, beta, eps, training)
+  deriv_tup = ∇batchnorm(data(g), data(b), data(x), Δ, running_mean, running_var, momentum,
+                         cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)
+  istracked(x) && @back(x, deriv_tup[1])
+  @back(b, deriv_tup[2])
+  @back(g, deriv_tup[3])
+end
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 5e363454..25832c07 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -104,46 +104,46 @@ mutable struct BatchNorm{F,V,W,N}
   σ::W  # moving std
   ϵ::N
   momentum::N
+  cache
   active::Bool
 end
 
 BatchNorm(chs::Integer, λ = identity;
           initβ = zeros, initγ = ones, ϵ = 1e-5, momentum = .1) =
   BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
-            zeros(chs), ones(chs), ϵ, momentum, true)
+            zeros(chs), ones(chs), ϵ, momentum, nothing, true)
 
-function (BN::BatchNorm)(x)
-  size(x, ndims(x)-1) == length(BN.β) ||
-    error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
-  γ, β = BN.γ, BN.β
+
+function batchnorm(γ, β, x, μ, σ, momentum; cache = nothing, alpha = 1, beta = 0, eps = 1.0e-5, training = true)
+  size(x, ndims(x)-1) == length(β) ||
+    error("BatchNorm expected $(length(β)) channels, got $(size(x, ndims(x)-1))")
   dims = length(size(x))
   channels = size(x, dims-1)
   affine_shape = ones(Int, dims)
   affine_shape[end-1] = channels
   m = prod(size(x)[1:end-2]) * size(x)[end]
 
-  if !BN.active
-    μ = reshape(BN.μ, affine_shape...)
-    σ = reshape(BN.σ, affine_shape...)
+  if !training
+    μ_curr = reshape(μ, affine_shape...)
+    σ_curr = reshape(σ, affine_shape...)
   else
     T = eltype(x)
 
-    ϵ = data(convert(T, BN.ϵ))
+    eps = Flux.data(convert(T, eps))
    axes = [1:dims-2; dims]  # axes to reduce along (all but channels axis)
-    μ = mean(x, axes)
-    σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
+    μ_curr = mean(x, axes)
+    σ_curr = sqrt.(mean((x .- μ_curr).^2, axes) .+ eps)
 
     # update moving mean/std
-    mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
-    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
-  end
-
-  let λ = BN.λ
-    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
+    mtm = Flux.data(convert(T, momentum))
+    μ .= (1 - mtm) .* μ .+ mtm .* squeeze(Flux.data(μ_curr), (axes...))
+    σ .= (1 - mtm) .* σ .+ mtm .* squeeze(Flux.data(σ_curr), (axes...)) .* m ./ (m - 1)
   end
+  reshape(γ, affine_shape...) .* ((x .- μ_curr) ./ σ_curr) .+ reshape(β, affine_shape...)
 end
 
+(BN::BatchNorm)(x) = BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = BN.cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
+
 children(BN::BatchNorm) = (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
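
Usage sketch (illustrative reviewer note, not part of the patch). The snippet below exercises the new `batchnorm`/`∇batchnorm` pair directly on plain CuArrays. It assumes a working CuArrays + CUDNN setup and that both functions are in scope (they are defined alongside the CUDNN bindings above, so a qualified name may be needed depending on where the file is included); shapes and hyperparameters are made up for the example.

    using Flux, CuArrays

    x = cu(rand(Float32, 7, 7, 3, 4))     # W × H × C × N
    g = cu(ones(Float32, 3))              # scale γ, one entry per channel
    b = cu(zeros(Float32, 3))             # shift β
    running_mean = cu(zeros(Float32, 3))  # updated in place when training = true
    running_var  = cu(ones(Float32, 3))   # likewise

    # Forward pass through cudnnBNForward!: training = true normalises with batch
    # statistics and folds them into the running estimates using the momentum (0.1).
    y = batchnorm(g, b, x, running_mean, running_var, 0.1; training = true)

    # Backward pass through cudnnBNBackward!: returns (dx, db, dg), the order
    # consumed by back(::typeof(_batchnorm), ...) above.
    dy = cu(ones(Float32, size(x)))
    dx, db, dg = ∇batchnorm(g, b, x, dy, running_mean, running_var, 0.1;
                            training = true)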