LayerNorm tweaks
commit b06884b912 (parent 11d53781b2)
@@ -38,4 +38,5 @@ These layers don't affect the structure of the network but may improve training
 ```@docs
 Flux.testmode!
 Dropout
+LayerNorm
 ```
@@ -7,7 +7,7 @@ module Flux
 using Juno, Requires
 using Lazy: @forward

-export Chain, Dense, RNN, LSTM, Dropout,
+export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
   SGD, ADAM, Momentum, Nesterov,
   param, params, mapleaves

@@ -80,31 +80,30 @@ function Base.show(io::IO, l::Dense)
 end

 """
-    ElementwiseLinear(in::Integer)
+    Diagonal(in::Integer)

 Creates an element-wise linear transformation layer with learnable
 vectors α and β:

     y = α .* x .+ β

-The input `x` must be a vector of length `in`, or a batch of vectors represented
-as an `in × N` matrix. The output `y` will be a vector or batch of length `in`.
+The input `x` must be an array where `size(x, 1) == in`.
 """
-struct ElementwiseLinear{T}
+struct Diagonal{T}
   α::T
   β::T
 end

-ElementwiseLinear(in::Integer; initα = ones, initβ = zeros) =
-  ElementwiseLinear(param(initα(in)), param(initβ(in)))
+Diagonal(in::Integer; initα = ones, initβ = zeros) =
+  Diagonal(param(initα(in)), param(initβ(in)))

-treelike(ElementwiseLinear)
+treelike(Diagonal)

-function (a::ElementwiseLinear)(x)
+function (a::Diagonal)(x)
   α, β = a.α, a.β
   α.*x .+ β
 end

-function Base.show(io::IO, l::ElementwiseLinear)
-  print(io, "ElementwiseLinear(", length(l.α), ")")
+function Base.show(io::IO, l::Diagonal)
+  print(io, "Diagonal(", length(l.α), ")")
 end
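For context, a minimal usage sketch of the renamed layer (a hypothetical example, assuming the `param`/`treelike` API of this Flux version; `Diagonal` is not in the export list above, so it is qualified):

```julia
using Flux

# Diagonal(5) holds a per-feature gain α and bias β and computes
# y = α .* x .+ β. With initα = ones and initβ = zeros it starts
# out as the identity transformation.
d = Flux.Diagonal(5)
x = rand(5, 3)   # a batch of three length-5 inputs (size(x, 1) == 5)
y = d(x)         # same shape as x; equal to x at initialisation
```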
@@ -43,3 +43,25 @@ function (a::Dropout)(x)
 end

 _testmode!(a::Dropout, test) = (a.active = !test)
+
+"""
+    LayerNorm(h::Integer)
+
+A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
+used with recurrent hidden states of size `h`. Normalises the mean/stddev of
+each input before applying a per-neuron gain/bias.
+"""
+struct LayerNorm{T}
+  diag::Diagonal{T}
+end
+
+LayerNorm(h::Integer) =
+  LayerNorm(Diagonal(h))
+
+treelike(LayerNorm)
+
+(a::LayerNorm)(x) = a.diag(normalise(x))
+
+function Base.show(io::IO, l::LayerNorm)
+  print(io, "LayerNorm(", length(l.diag.α), ")")
+end
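A hypothetical sketch of the new layer in use, assuming the API introduced above (`LayerNorm` is added to Flux's exports by this commit):

```julia
using Flux

# LayerNorm(10) normalises each column of its input to mean 0 and
# standard deviation 1 via `normalise`, then applies the learnable
# per-neuron gain/bias of the wrapped Diagonal(10).
ln = LayerNorm(10)
h  = randn(10, 16)   # e.g. a batch of 16 recurrent hidden states
y  = ln(h)           # normalised, then scaled and shifted

# Typical placement after a recurrent layer (illustrative only):
m = Chain(RNN(10, 10), LayerNorm(10))
```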
@@ -14,24 +14,12 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
 end

 """
-    layernormalization(α=1.0, β=0.0)
+    normalise(x::AbstractVecOrMat)

-Creates a normalization layer based on https://arxiv.org/pdf/1607.06450.pdf
-
-The differences are:
-
-1) std here divides by N-1 (as does std in Julia) vs the paper's N
-2) this layer's α and β are constant numbers (i.e. not learnable vectors)
-
-To achieve the same effect of learnable vectors α and β one can use
-the ElementwiseLinear layer
+Normalise each column of `x` to mean 0 and standard deviation 1.
 """
-function layernormalization(α=1.0, β=0.0)
-  function layer(y)
-    _mean = mean(y)
-    _std = sqrt.(sum((y.-_mean).^2) ./ (length(y)-1))
-    _std /= α
-    _mean -= β*_std
-    return (y .- _mean) ./ _std
-  end
-end
+function normalise(x::AbstractVecOrMat)
+  μ′ = mean(x, 1)
+  σ′ = std(x, 1, mean = μ′)
+  return (x .- μ′) ./ σ′
+end
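A small worked check of the replacement (hypothetical numbers; assumes the Julia 0.6-era reductions `mean(x, 1)` and `std(x, 1; mean = …)` used in the diff):

```julia
x = [1.0 10.0;
     2.0 20.0;
     3.0 30.0]

μ′ = mean(x, 1)            # 1×2 row of column means: [2.0 20.0]
σ′ = std(x, 1, mean = μ′)  # corrected (N-1) stds:    [1.0 10.0]
(x .- μ′) ./ σ′            # each column becomes [-1.0, 0.0, 1.0]
```

Note that, as with Julia's `std`, the divisor is N-1 rather than the paper's N, one of the differences the old `layernormalization` docstring called out.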