fix batchnorm
This commit is contained in:
parent
1fc584102d
commit
a140c31f72
@ -135,8 +135,7 @@ function (BN::BatchNorm)(x)
|
|||||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||||
dims = length(size(x))
|
dims = length(size(x))
|
||||||
channels = size(x, dims-1)
|
channels = size(x, dims-1)
|
||||||
affine_shape = ones(Int, dims)
|
affine_shape = ntuple(i->i == ndims(x) - 1 ? size(x, i) : 1, ndims(x))
|
||||||
affine_shape[end-1] = channels
|
|
||||||
m = prod(size(x)[1:end-2]) * size(x)[end]
|
m = prod(size(x)[1:end-2]) * size(x)[end]
|
||||||
γ = reshape(BN.γ, affine_shape...)
|
γ = reshape(BN.γ, affine_shape...)
|
||||||
β = reshape(BN.β, affine_shape...)
|
β = reshape(BN.β, affine_shape...)
|
||||||
@ -151,9 +150,10 @@ function (BN::BatchNorm)(x)
|
|||||||
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
|
||||||
ϵ = convert(T, BN.ϵ)
|
ϵ = convert(T, BN.ϵ)
|
||||||
# update moving mean/variance
|
# update moving mean/variance
|
||||||
mtm = convert(T, BN.momentum)
|
mtm = BN.momentum
|
||||||
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* reshape(μ, :)
|
S = eltype(BN.μ)
|
||||||
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* reshape(σ², :)
|
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* S.(reshape(μ, :))
|
||||||
|
BN.σ² = (1 - mtm) .* BN.σ² .+ (mtm * m / (m - 1)) .* S.(reshape(σ², :))
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = BN.λ
|
let λ = BN.λ
|
||||||
|
Loading…
Reference in New Issue
Block a user