Allow multidimensional inputs to batchnorm.

BatchNorm can now be used with convolutional layers as well as
dense layers, through the same API.
This commit is contained in:
Brad Safnuk 2018-03-15 21:48:59 -04:00
parent 72f13834f5
commit 6653ec86d9
2 changed files with 15 additions and 4 deletions

View File

@ -106,6 +106,11 @@ BatchNorm(dims::Integer...; λ = identity,
function (BN::BatchNorm)(x)
λ, γ, β = BN.λ, BN.γ, BN.β
dims = length(size(x))
channels = size(x, dims-1)
affine_shape = ones(Int, dims)
affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end]
if !BN.active
μ = BN.μ
@ -114,9 +119,9 @@ function (BN::BatchNorm)(x)
T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
m = size(x, 2) # batch size
μ = mean(x, 2)
σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, axes)
σ = sqrt.(mean((x .- μ).^2, axes) .+ ϵ)
# update moving mean/std
mtm = data(convert(T, BN.momentum))
@ -124,7 +129,7 @@ function (BN::BatchNorm)(x)
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1)
end
λ.(γ .* ((x .- μ) ./ σ) .+ β)
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
end
children(BN::BatchNorm) =

View File

@ -77,4 +77,10 @@ end
x = m(x).data
@test x[1] ≈ σ((1 - 0.3) / 1.1449489742783179)
end
# Multidimensional input: applying BatchNorm to a 2×2×3 array must give the same
# result as applying it to the equivalent 2×6 matrix and mapping the output back.
# (The layer treats the second-to-last axis as the channel axis, so the axis swap
# below puts the channels of `x` where a 2-D input is expected to have them.)
let m = BatchNorm(2), x = param(reshape(1:12, 2, 2, 3))
# Swap the first two axes of `x`, then flatten the result into a 2×6 matrix.
y = reshape(permutedims(x, [2, 1, 3]), 2, 6)
# Run the layer on the matrix, then undo the flattening and the axis swap.
y = permutedims(reshape(m(y), 2, 2, 3), [2, 1, 3])
# The direct 3-D call must match the 2-D reference exactly.
@test m(x) == y
end
end