batchnorm fix

This commit is contained in:
Mike J Innes 2018-02-13 14:02:35 +00:00
parent 820cd3ae42
commit 8432d8db06

View File

@@ -113,15 +113,15 @@ function (BN::BatchNorm)(x)
else
T = eltype(x)
ϵ = T(BN.ϵ)
ϵ = data(convert(T, BN.ϵ))
m = size(x, 2) # batch size
μ = mean(x, 2)
σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ)
# update moving mean/std
mtm = T(BN.momentum)
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1)
mtm = data(convert(T, BN.momentum))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ)
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1)
end
λ.(γ .* ((x .- μ) ./ σ) .+ β)