From 8432d8db063e8ab76f66c7c41b3bd89d91535558 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 13 Feb 2018 14:02:35 +0000 Subject: [PATCH] batchnorm fix --- src/layers/normalisation.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index a018a073..69854f44 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -113,15 +113,15 @@ function (BN::BatchNorm)(x) else T = eltype(x) - ϵ = T(BN.ϵ) + ϵ = data(convert(T, BN.ϵ)) m = size(x, 2) # batch size μ = mean(x, 2) σ = sqrt.(sum((x .- μ).^2, 2) ./ m .+ ϵ) # update moving mean/std - mtm = T(BN.momentum) - BN.μ = (1 - mtm) .* BN.μ .+ mtm .* μ.data - BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1) + mtm = data(convert(T, BN.momentum)) + BN.μ = (1 - mtm) .* BN.μ .+ mtm .* data(μ) + BN.σ = (1 - mtm) .* BN.σ .+ mtm .* data(σ) .* m ./ (m - 1) end λ.(γ .* ((x .- μ) ./ σ) .+ β)