Update normalise.jl

This commit is contained in:
Sklan 2019-02-20 13:47:31 +05:30 committed by GitHub
parent ebf50f4e1c
commit 6044421c5c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -113,34 +113,32 @@ BatchNorm(chs::Integer, λ = identity;
function (BN::BatchNorm)(x) function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) || size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))") error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
γ, β = BN.γ, BN.β
dims = length(size(x)) dims = length(size(x))
channels = size(x, dims-1) channels = size(x, dims-1)
affine_shape = ones(Int, dims) affine_shape = ones(Int, dims)
affine_shape[end-1] = channels affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end] m = prod(size(x)[1:end-2]) * size(x)[end]
γ = reshape(BN.γ, affine_shape...)
β = reshape(BN.β, affine_shape...)
if !BN.active if !BN.active
μ = reshape(BN.μ, affine_shape...) μ = reshape(BN.μ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...) σ² = reshape(BN.σ², affine_shape...)
ϵ = BN.ϵ
else else
T = eltype(x) T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis) axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
μ = mean(x, dims = axes) μ = mean(x, dims = axes)
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
ϵ = data(convert(T, BN.ϵ))
# update moving mean/std # update moving mean/std
mtm = data(convert(T, BN.momentum)) mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :) BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1)) BN.σ² = (1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1)
end end
let λ = BN.λ let λ = BN.λ
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) = (x .- μ) ./ sqrt.(σ² .+ ϵ)
# This is intentionally not fused because of an extreme slowdown doing so λ.(γ .* .+ β)
λ.(temp .+ reshape(β, affine_shape...))
end end
end end