diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index a018a073..69854f44 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -113,15 +113,15 @@ function (BN::BatchNorm)(x) else T = eltype(x) - ϵ = T(BN.ϵ) + ϵ = data(convert(T, BN.ϵ)) m = size(x, 2) # batch size μ = mean(x, 2) σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ) # update moving mean/std - mtm = T(BN.momentum) - BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data - BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1) + mtm = data(convert(T, BN.momentum)) + BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ) + BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1) end λ.(γ .* ((x .- μ) ./ σ) .+ β)