From 88bd8a8fbd17139c9d2a5ef01cb575e73f604bf9 Mon Sep 17 00:00:00 2001 From: Iblis Lin Date: Thu, 2 Nov 2017 13:40:06 +0800 Subject: [PATCH] batchnorm: make CuArrays happy --- src/layers/normalisation.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/layers/normalisation.jl b/src/layers/normalisation.jl index ee606b40..d4a9c94e 100644 --- a/src/layers/normalisation.jl +++ b/src/layers/normalisation.jl @@ -66,7 +66,7 @@ julia> m = Chain( softmax) Chain(Dense(784, 64), BatchNorm(64, λ = NNlib.relu), Dense(64, 10), BatchNorm(10), NNlib.softmax) -julia> opt = SGD(params(m), 10) # a crazy learning rate +julia> opt = SGD(params(m), 10, decay = .1) # a crazy learning rate ``` """ mutable struct BatchNorm{F,V,N} @@ -85,6 +85,8 @@ BatchNorm(dims::Integer...; λ = identity, BatchNorm(λ, param(initβ(dims)), param(initγ(dims)), 0., 1., ϵ, momentum, true) function (BN::BatchNorm)(x) + λ, γ, β = BN.λ, BN.γ, BN.β + if !BN.active μ = BN.μ σ = BN.σ @@ -102,7 +104,7 @@ function (BN::BatchNorm)(x) BN.σ = (1 - mtm) .* BN.σ .+ mtm .* σ.data .* m ./ (m - 1) end - BN.λ.(BN.γ .* ((x .- μ) ./ σ) .+ BN.β) + λ.(γ .* ((x .- μ) ./ σ) .+ β) end children(BN::BatchNorm) =