Get the batchnorm working without cache

This commit is contained in:
Avik Pal 2018-06-28 12:04:25 +05:30
parent 4916c8e6da
commit 8f43258ab7
2 changed files with 30 additions and 29 deletions

View File

@ -1,6 +1,6 @@
using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t, using CuArrays.CUDNN: @check, libcudnn, cudnnStatus_t, cudnnTensorDescriptor_t,
cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc cudnnBatchNormMode_t, cudnnHandle_t, libcudnn_handle, cudnnDataType, TensorDesc, FilterDesc
import Flux.data import ..Flux: data
mutable struct DropoutDesc mutable struct DropoutDesc
ptr::Ptr{Void} ptr::Ptr{Void}
@ -63,7 +63,7 @@ function cudnnBNForward!(y::CuArray{T}, g::CuArray{T}, b::CuArray{T}, x::CuArray
end end
xd = TensorDesc(x) xd = TensorDesc(x)
yd = TensorDesc(y) yd = TensorDesc(y)
gd = TensorDesc(T, (1,1,length(g),1)) gd = TensorDesc(T, dims)
if training if training
@ -136,7 +136,7 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
xd = TensorDesc(x) xd = TensorDesc(x)
dyd = TensorDesc(dy) dyd = TensorDesc(dy)
dxd = TensorDesc(dx) dxd = TensorDesc(dx)
gd = TensorDesc(T, (1,1,length(g),1)) gd = TensorDesc(T, _wsize(x))
if cache !== nothing if cache !== nothing
mean, ivar = cache.mean, cache.ivar mean, ivar = cache.mean, cache.ivar
info("mean and ivar are fetched from the cache") info("mean and ivar are fetched from the cache")
@ -167,9 +167,9 @@ function cudnnBNBackward!(dg::CuArray{T}, g::CuArray{T}, db::CuArray{T},
gd, g, dg, db, gd, g, dg, db,
eps, mean, ivar) eps, mean, ivar)
else else
ivar = 1 ./ sqrt.(reshape(running_var, (1, 1, length(running_var), 1)) .+ eps) ivar = 1 ./ sqrt.(reshape(running_var, _wsize(x)) .+ eps)
dx .= dy .* reshape(g, (1, 1, length(g), 1)) .* ivar dx .= dy .* reshape(g, _wsize(x)) .* ivar
dg .= squeeze(sum(dy .* (x .- reshape(running_mean, (1, 1, length(running_mean), 1))) .* ivar, _reddims(dy)), (1,2,4)) dg .= squeeze(sum(dy .* (x .- reshape(running_mean, _wsize(x))) .* ivar, _reddims(dy)), (1,2,4))
db .= squeeze(sum(dy, _reddims(dy)), (1,2,4)) db .= squeeze(sum(dy, _reddims(dy)), (1,2,4))
end end
end end
@ -179,6 +179,13 @@ end
import ..Flux: Flux import ..Flux: Flux
import ..Tracker: track, back, @back, istracked, TrackedArray import ..Tracker: track, back, @back, istracked, TrackedArray
CuParam{T,N} = Union{CuArray{T,N},TrackedArray{T,N,CuArray{T,N}}}
CuParam45{T} = Union{CuParam{T,4},CuParam{T,5}}
CuBatchNorm{T} = Flux.BatchNorm{<:Union{typeof(identity),typeof(relu)},<:CuParam{T,1},<:CuParam{T,1},<:T}
(BN::BatchNorm)(x::CuParam45{T}) =
batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = nothing, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)
_batchnorm(g, b, x, running_mean, running_var, momentum, _batchnorm(g, b, x, running_mean, running_var, momentum,
cache, alpha, beta, eps, training) = cache, alpha, beta, eps, training) =
batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, alpha = alpha, beta = beta, eps = eps, training = training) batchnorm(g, b, x, running_mean, running_var, momentum, cache = cache, alpha = alpha, beta = beta, eps = eps, training = training)

View File

@ -104,7 +104,6 @@ mutable struct BatchNorm{F,V,W,N}
σ::W # moving std σ::W # moving std
ϵ::N ϵ::N
momentum::N momentum::N
cache
active::Bool active::Bool
end end
@ -113,44 +112,39 @@ BatchNorm(chs::Integer, λ = identity;
BatchNorm(λ, param(initβ(chs)), param(initγ(chs)), BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
zeros(chs), ones(chs), ϵ, momentum, nothing, true) zeros(chs), ones(chs), ϵ, momentum, nothing, true)
function (BN::BatchNorm)(x)
function batchnorm(γ, β, x, μ, σ, momentum; cache = nothing, alpha = 1, beta = 0, eps = 1.0e-5, training = true) size(x, ndims(x)-1) == length(BN.β) ||
size(x, ndims(x)-1) == length(β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
γ, β = BN.γ, BN.β
dims = length(size(x)) dims = length(size(x))
channels = size(x, dims-1) channels = size(x, dims-1)
affine_shape = ones(Int, dims) affine_shape = ones(Int, dims)
affine_shape[end-1] = channels affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end] m = prod(size(x)[1:end-2]) * size(x)[end]
if !training if !BN.active
μ_curr = reshape(μ, affine_shape...) μ = reshape(BN.μ, affine_shape...)
σ_curr = reshape(σ, affine_shape...) σ = reshape(BN.σ, affine_shape...)
else else
T = eltype(x) T = eltype(x)
eps = Flux.data(convert(T, eps)) ϵ = data(convert(T, BN.ϵ))
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ_curr = mean(x, axes) μ = mean(x, axes)
σ_curr = sqrt.(mean((x .- μ_curr).^2, axes) .+ eps) σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
# update moving mean/std # update moving mean/std
mtm = Flux.data(convert(T, momentum)) mtm = data(convert(T, BN.momentum))
μ .= (1 - mtm) .* μ .+ mtm .* squeeze(Flux.data(μ_curr), (axes...)) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* squeeze(data(μ), (axes...))
σ .= (1 - mtm) .* σ .+ mtm .* squeeze(Flux.data(σ_curr), (axes...)) .* m ./ (m - 1) BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
end
reshape(γ, affine_shape...) .* ((x .- μ_curr) ./ σ_curr) .+ reshape(β, affine_shape...)
end end
(BN::BatchNorm)(x) = BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = BN.cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)) let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
end
end
Flux.treelike(BatchNorm) treelike(BatchNorm)
# children(BN::BatchNorm) =
# (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
#
# mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
# BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active)
_testmode!(BN::BatchNorm, test) = (BN.active = !test) _testmode!(BN::BatchNorm, test) = (BN.active = !test)