From c4f87ff15c9f4114b06718d35d9ac3fd8bfa35a9 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Tue, 11 Sep 2018 16:21:55 +0530
Subject: [PATCH] Minor fixes:

---
 src/layers/normalise.jl | 57 +++++++++++++++--------------------------
 1 file changed, 21 insertions(+), 36 deletions(-)

diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 41252bc9..1961fbe3 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -1,6 +1,6 @@
 """
-    testmode!(m, val=true)
-
+    testmode!(m)
+    testmode!(m, false)
 Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
 (or back to training mode with `false`).
 """
@@ -13,11 +13,9 @@ _testmode!(m, test) = nothing
 
 """
     Dropout(p)
-
 A Dropout layer. For each input, either sets that input to `0` (with probability
 `p`) or scales it by `1/(1-p)`. This is used as a regularisation, i.e. it
 reduces overfitting during training.
-
 Does nothing to the input once in [`testmode!`](@ref).
 """
 mutable struct Dropout{F}
@@ -43,9 +41,7 @@ end
 _testmode!(a::Dropout, test) = (a.active = !test)
 
 """
-
     LayerNorm(h::Integer)
-
 A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
 used with recurrent hidden states of size `h`. Normalises the mean/stddev of
 each input before applying a per-neuron gain/bias.
@@ -69,23 +65,17 @@
 """
     BatchNorm(channels::Integer, σ = identity;
               initβ = zeros, initγ = ones, ϵ = 1e-8, momentum = .1)
-
 Batch Normalization layer. The `channels` input should be the size of the
 channel dimension in your data (see below).
-
 Given an array with `N` dimensions, call the `N-1`th the channel dimension. (For
 a batch of feature vectors this is just the data dimension, for `WHCN` images
 it's the usual channel dimension.)
-
 `BatchNorm` computes the mean and variance for each `W×H×1×N` slice and
 shifts them to have a new mean and variance (corresponding to the learnable,
 per-channel `bias` and `scale` parameters).
-
 See [Batch Normalization: Accelerating Deep Network Training by Reducing
 Internal Covariate Shift](https://arxiv.org/pdf/1502.03167.pdf).
-
 Example:
-
 ```julia
 m = Chain(
   Dense(28^2, 64),
@@ -93,32 +83,23 @@ m = Chain(
   Dense(64, 10),
   BatchNorm(10),
   softmax)
-
-y = m(rand(28^2, 10))
 ```
-
-To use the layer at test time set [`testmode!(m, true)`](@ref).
 """
-mutable struct BatchNorm
-  λ  # activation function
-  β  # bias
-  γ  # scale
-  μ  # moving mean
-  σ² # moving var
-  ϵ
-  momentum
+mutable struct BatchNorm{F,V,W,N}
+  λ::F  # activation function
+  β::V  # bias
+  γ::V  # scale
+  μ::W  # moving mean
+  σ²::W # moving variance
+  ϵ::N
+  momentum::N
   active::Bool
 end
 
-# NOTE: Keeping the ϵ smaller than 1e-5 is not supported by CUDNN
-function BatchNorm(chs::Integer, λ = identity;
-                   initβ = (i) -> zeros(i),
-                   initγ = (i) -> ones(i),
-                   ϵ = 1f-5,
-                   momentum = 0.1)
+BatchNorm(chs::Integer, λ = identity;
+          initβ = (i) -> zeros(i), initγ = (i) -> ones(i), ϵ = 1e-5, momentum = .1) =
   BatchNorm(λ, param(initβ(chs)), param(initγ(chs)),
-            zeros(Float32, chs), ones(Float32, chs), ϵ, momentum, true)
-end
+            zeros(chs), ones(chs), ϵ, momentum, true)
 
 function (BN::BatchNorm)(x)
   size(x, ndims(x)-1) == length(BN.β) ||
@@ -132,7 +113,7 @@
 
   if !BN.active
     μ = reshape(BN.μ, affine_shape...)
-    σ = reshape(BN.σ, affine_shape...)
+    σ² = reshape(BN.σ², affine_shape...)
   else
     T = eltype(x)
@@ -143,8 +124,8 @@
 
     # update moving mean/std
     mtm = data(convert(T, BN.momentum))
-    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
-    BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* dropdims(data(σ²), dims = (axes...)) .* m ./ (m - 1))
+    BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = axes)
+    BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* dropdims(data(σ²), dims = axes) .* m ./ (m - 1))
   end
 
   let λ = BN.λ
@@ -152,7 +133,11 @@ end
 
-@treelike BatchNorm
+children(BN::BatchNorm) =
+  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
+
+mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
+  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
 
 _testmode!(BN::BatchNorm, test) = (BN.active = !test)
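As a quick sanity check of the behaviour this patch touches, here is a minimal usage sketch of the patched layer. It reuses the model from the `BatchNorm` docstring above and the `rand(28^2, 10)` batch from the old docstring example; the `Flux.testmode!` and `mapchildren` calls follow the API referenced in the diff, and the `cu`-mapping line is only a hedged illustration (it assumes CuArrays is set up):

```julia
using Flux

# The model from the BatchNorm docstring example above.
m = Chain(
  Dense(28^2, 64),
  BatchNorm(64, relu),
  Dense(64, 10),
  BatchNorm(10),
  softmax)

x = rand(28^2, 10)  # a batch of ten flattened 28×28 inputs

# Training mode (the default): normalise with batch statistics and
# update each layer's moving mean/variance in place.
y_train = m(x)

# Test mode: normalise with the stored moving statistics instead.
Flux.testmode!(m)
y_test = m(x)

# Back to training mode.
Flux.testmode!(m, false)

# With the new mapchildren definition, parameters and moving statistics
# can be mapped to another array type, e.g. (assuming CuArrays is loaded):
# bn_gpu = Flux.mapchildren(cu, m[3])
```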