diff --git a/src/cuda/cudnn.jl b/src/cuda/cudnn.jl
index dd1775ad..088876e4 100644
--- a/src/cuda/cudnn.jl
+++ b/src/cuda/cudnn.jl
@@ -1,6 +1,6 @@
 using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
   cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
-import Flux.data
+import ..Flux: data
 
 mutable struct DropoutDesc
   ptr::Ptr{Void}
@@ -63,7 +63,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
   end
   xd = TensorDesc(x)
   yd = TensorDesc(y)
-  gd = TensorDesc(T, (1,1,length(g),1))
+  gd = TensorDesc(T, dims)
 
   if training
@@ -136,7 +136,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
   xd = TensorDesc(x)
   dyd = TensorDesc(dy)
   dxd = TensorDesc(dx)
-  gd = TensorDesc(T, (1,1,length(g),1))
+  gd = TensorDesc(T, _wsize(x))
   if cache !== nothing
     mean, ivar = cache.mean, cache.ivar
     info("mean and ivar are fetched from the cache")
@@ -167,9 +167,9 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
             gd, g, dg, db,
             eps, mean, ivar)
   else
-    ivar = 1 ./ sqrt.(reshape(running_var, (1, 1, length(running_var), 1)) .+ eps)
-    dx .= dy .* reshape(g, (1, 1, length(g), 1)) .* ivar
-    dg .= squeeze(sum(dy .* (x .- reshape(running_mean, (1, 1, length(running_mean), 1))) .* ivar, _reddims(dy)), (1,2,4))
+    ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
+    dx .= dy .* reshape(g, _wsize(x)) .* ivar
+    dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), (1,2,4))
     db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
   end
 end
@@ -179,6 +179,13 @@ end
 import ..Flux: Flux
 import ..Tracker: track, back, @back, istracked, TrackedArray
 
+CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
+CuParam45{T} = Union{CuParam{T,4},CuParam{T,5}}
+CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},<:CuParam{T,1},<:CuParam{T,1},<:T}
+
+(BN::Flux.BatchNorm)(x::CuParam45{T}) where T =
+  BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = nothing, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
+
 _batchnorm(g, b, x, running_mean, running_var, momentum,
            cache, alpha, beta, eps, training) =
   batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache,
             alpha = alpha, beta = beta, eps = eps, training = training)
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index e43c76b7..04082a73 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -104,7 +104,6 @@ mutable struct BatchNorm{F,V,W,N}
   σ::W # moving std
   ϵ::N
   momentum::N
-  cache
   active::Bool
 end
 
@@ -113,44 +112,39 @@ BatchNorm(chs::Integer, λ = identity;
   BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
-            zeros(chs), ones(chs), ϵ, momentum, nothing, true)
+            zeros(chs), ones(chs), ϵ, momentum, true)
 
-
-function batchnorm(γ, β, x, μ, σ, momentum; cache = nothing, alpha = 1, beta = 0, eps = 1.0e-5, training = true)
-  size(x, ndims(x)-1) == length(β) ||
+function (BN::BatchNorm)(x)
+  size(x, ndims(x)-1) == length(BN.β) ||
     error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
+  γ, β = BN.γ, BN.β
   dims = length(size(x))
   channels = size(x, dims-1)
   affine_shape = ones(Int, dims)
   affine_shape[end-1] = channels
   m = prod(size(x)[1:end-2]) * size(x)[end]
-  if !training
-    μ_curr = reshape(μ, affine_shape...)
-    σ_curr = reshape(σ, affine_shape...)
+  if !BN.active
+    μ = reshape(BN.μ, affine_shape...)
+    σ = reshape(BN.σ, affine_shape...)
   else
     T = eltype(x)
-    eps = Flux.data(convert(T, eps))
+    ϵ = data(convert(T, BN.ϵ))
     axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
-    μ_curr = mean(x, axes)
-    σ_curr = sqrt.(mean((x .- μ_curr).^2, axes) .+ eps)
+    μ = mean(x, axes)
+    σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
 
     # update moving mean/std
-    mtm = Flux.data(convert(T, momentum))
-    μ .= (1 - mtm) .* μ .+ mtm .* squeeze(Flux.data(μ_curr), (axes...))
-    σ .= (1 - mtm) .* σ .+ mtm .* squeeze(Flux.data(σ_curr), (axes...)) .* m ./ (m - 1)
+    mtm = data(convert(T, BN.momentum))
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
+    BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
+  end
+
+  let λ = BN.λ
+    λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
   end
-
-  reshape(γ, affine_shape...) .* ((x .- μ_curr) ./ σ_curr) .+ reshape(β, affine_shape...)
 end
 
-(BN::BatchNorm)(x) = BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = BN.cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active))
-
-Flux.treelike(BatchNorm)
-
-# children(BN::BatchNorm) =
-#   (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
-#
-# mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
-#   BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
+treelike(BatchNorm)
 
 _testmode!(BN::BatchNorm, test) = (BN.active = !test)
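
For reference, a rough usage sketch of the new GPU code path (not part of the diff). It assumes a working CuArrays installation and uses Flux's `mapleaves`/`cu` and `testmode!` helpers; exact helper names (e.g. `gpu`) vary between Flux versions.

using Flux, CuArrays

m = BatchNorm(3)                       # BatchNorm over 3 channels
m = mapleaves(cu, m)                   # move parameters and running stats to the GPU

x = cu(rand(Float32, 28, 28, 3, 16))   # 4-D WHCN input: matches the CuParam45 dispatch,
y = m(x)                               # so the call goes through the CUDNN batchnorm
                                       # wrapper added in src/cuda/cudnn.jl

Flux.testmode!(m)                      # sets m.active = false; uses the running statistics
ŷ = m(x)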