Merge pull request #95 from FluxML/layernorm

Layer Normalisation
Mike J Innes 2017-11-21 17:12:49 +01:00 committed by GitHub
commit feb35783e6
7 changed files with 73 additions and 1 deletion


@@ -38,4 +38,5 @@ These layers don't affect the structure of the network but may improve training
```@docs
Flux.testmode!
Dropout
LayerNorm
```


@@ -7,7 +7,7 @@ module Flux
using Juno, Requires
using Lazy: @forward
-export Chain, Dense, RNN, LSTM, Dropout,
+export Chain, Dense, RNN, LSTM, Dropout, LayerNorm,
  SGD, ADAM, Momentum, Nesterov,
  param, params, mapleaves


@@ -78,3 +78,32 @@ function Base.show(io::IO, l::Dense)
  l.σ == identity || print(io, ", ", l.σ)
  print(io, ")")
end

"""
Diagonal(in::Integer)
Creates an element-wise linear transformation layer with learnable
vectors `α` and `β`:
y = α .* x .+ β
The input `x` must be a array where `size(x, 1) == in`.
"""
struct Diagonal{T}
α::T
β::T
end
Diagonal(in::Integer; initα = ones, initβ = zeros) =
Diagonal(param(initα(in)), param(initβ(in)))
treelike(Diagonal)
function (a::Diagonal)(x)
α, β = a.α, a.β
α.*x .+ β
end
function Base.show(io::IO, l::Diagonal)
print(io, "Diagonal(", length(l.α), ")")
end
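A minimal usage sketch, not part of the diff, showing what the new layer computes; it assumes `Diagonal` is reached via the `Flux.` prefix, since this PR only adds `LayerNorm` to the exports:

```julia
using Flux

d = Flux.Diagonal(5)   # learnable α initialised to ones(5), β to zeros(5)
x = rand(5, 3)         # size(x, 1) == 5; each column is one example
y = d(x)               # computes α .* x .+ β, so initially just a copy of x
```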


@@ -43,3 +43,25 @@ function (a::Dropout)(x)
end

_testmode!(a::Dropout, test) = (a.active = !test)

"""
LayerNorm(h::Integer)
A [normalisation layer](https://arxiv.org/pdf/1607.06450.pdf) designed to be
used with recurrent hidden states of size `h`. Normalises the mean/stddev of
each input before applying a per-neuron gain/bias.
"""
struct LayerNorm{T}
diag::Diagonal{T}
end
LayerNorm(h::Integer) =
LayerNorm(Diagonal(h))
treelike(LayerNorm)
(a::LayerNorm)(x) = a.diag(normalise(x))
function Base.show(io::IO, l::LayerNorm)
print(io, "LayerNorm(", length(l.diag.α), ")")
end
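A hedged usage sketch, not part of the diff, combining the exported `LayerNorm` with a `Dense` layer in a `Chain`; the sizes here are illustrative:

```julia
using Flux

m = Chain(Dense(10, 5), LayerNorm(5))  # normalise the Dense output, then per-neuron gain/bias
x = rand(10, 2)                        # two input columns
m(x)                                   # each column is normalised before the Diagonal transform
```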


@@ -12,3 +12,14 @@ function logitcrossentropy(logŷ::AbstractVecOrMat, y::AbstractVecOrMat)
  ypred = logŷ .- log.(sum(exp.(logŷ), 1))
  -sum(y .* ypred) / size(y, 2)
end

"""
normalise(x::AbstractVecOrMat)
Normalise each column of `x` to mean 0 and standard deviation 1.
"""
function normalise(x::AbstractVecOrMat)
μ′ = mean(x, 1)
σ = std(x, 1, mean = μ′)
return (x .- μ′) ./ σ
end
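An illustrative check, not part of the diff, of the column-wise behaviour using the same Julia 0.6 reduction syntax as the source; `normalise` is assumed to stay unexported, hence the `Flux.` qualifier:

```julia
x = rand(3, 4)
y = Flux.normalise(x)
mean(y, 1)   # each column mean ≈ 0
std(y, 1)    # each column standard deviation ≈ 1
```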


@@ -58,6 +58,12 @@ Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))

# Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) =
  sqrt.(sum((x .- mean).^2) ./ (length(x)-1))
Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
  sqrt.(sum((x .- mean).^2, dim) ./ (size(x, dim)-1))

back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
back(::typeof(mean), Δ, xs::TrackedArray, region) =
  back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
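A plain-array sanity check, a sketch rather than part of the diff, showing that the dimension-wise formula used in the `std` hack matches `Base.std` with the usual n-1 correction:

```julia
x = rand(4, 3)
μ = mean(x, 1)
sqrt.(sum((x .- μ).^2, 1) ./ (size(x, 1) - 1)) ≈ std(x, 1)  # true
```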


@@ -34,6 +34,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
  @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end

@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))

@test gradtest(rand(5)) do x
  y = x.^2
  2y + x