Update normalise.jl
This commit is contained in:
parent
ebf50f4e1c
commit
6044421c5c
@ -113,34 +113,32 @@ BatchNorm(chs::Integer, λ = identity;
|
|||||||
function (BN::BatchNorm)(x)
|
function (BN::BatchNorm)(x)
|
||||||
size(x, ndims(x)-1) == length(BN.β) ||
|
size(x, ndims(x)-1) == length(BN.β) ||
|
||||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||||
γ, β = BN.γ, BN.β
|
|
||||||
dims = length(size(x))
|
dims = length(size(x))
|
||||||
channels = size(x, dims-1)
|
channels = size(x, dims-1)
|
||||||
affine_shape = ones(Int, dims)
|
affine_shape = ones(Int, dims)
|
||||||
affine_shape[end-1] = channels
|
affine_shape[end-1] = channels
|
||||||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
m = prod(size(x)[1:end-2]) * size(x)[end]
|
||||||
|
γ = reshape(BN.γ, affine_shape...)
|
||||||
|
β = reshape(BN.β, affine_shape...)
|
||||||
if !BN.active
|
if !BN.active
|
||||||
μ = reshape(BN.μ, affine_shape...)
|
μ = reshape(BN.μ, affine_shape...)
|
||||||
σ² = reshape(BN.σ², affine_shape...)
|
σ² = reshape(BN.σ², affine_shape...)
|
||||||
|
ϵ = BN.ϵ
|
||||||
else
|
else
|
||||||
T = eltype(x)
|
T = eltype(x)
|
||||||
|
|
||||||
ϵ = data(convert(T, BN.ϵ))
|
|
||||||
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
|
||||||
μ = mean(x, dims = axes)
|
μ = mean(x, dims = axes)
|
||||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||||
|
ϵ = data(convert(T, BN.ϵ))
|
||||||
# update moving mean/std
|
# update moving mean/std
|
||||||
mtm = data(convert(T, BN.momentum))
|
mtm = data(convert(T, BN.momentum))
|
||||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(data(μ), :)
|
||||||
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1))
|
BN.σ² = (1 - mtm) .* BN.σ² .+ mtm .* reshape(data(σ²), :) .* m ./ (m - 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = BN.λ
|
let λ = BN.λ
|
||||||
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
|
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||||
# This is intentionally not fused because of an extreme slowdown doing so
|
λ.(γ .* x̂ .+ β)
|
||||||
λ.(temp .+ reshape(β, affine_shape...))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user