diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 25832c07..e43c76b7 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -144,11 +144,13 @@ end (BN::BatchNorm)(x) = BN.λ.(batchnorm(BN.γ, BN.β, x, BN.μ, BN.σ, BN.momentum; cache = BN.cache, alpha = 1, beta = 0, eps = BN.ϵ, training = BN.active)) -children(BN::BatchNorm) = - (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active) +Flux.treelike(BatchNorm) -mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) - BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active) +# children(BN::BatchNorm) = +# (BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active) +# +# mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN) +# BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ), BN.ϵ, BN.momentum, BN.active) _testmode!(BN::BatchNorm, test) = (BN.active = !test)