Merge pull request #656 from thebhatman/patch-3

Added AlphaDropout which is used in SNNs.
Elliot Saba 2019-03-07 10:58:44 -08:00 committed by GitHub
commit bc12a4d55a
3 changed files with 33 additions and 1 deletion

docs/src/models/layers.md

@@ -50,5 +50,6 @@ These layers don't affect the structure of the network but may improve training
 Flux.testmode!
 BatchNorm
 Dropout
+AlphaDropout
 LayerNorm
 ```

src/Flux.jl

@@ -7,7 +7,7 @@ using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
 
 export Chain, Dense, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
-  DepthwiseConv, Dropout, LayerNorm, BatchNorm, InstanceNorm,
+  DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm,
   params, mapleaves, cpu, gpu, f32, f64
 
 @reexport using NNlib

src/layers/normalise.jl

@@ -43,6 +43,37 @@ end
 _testmode!(a::Dropout, test) = (a.active = !test)
+
+"""
+    AlphaDropout(p)
+A dropout layer. It is used in Self-Normalizing Neural Networks
+(https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf).
+The AlphaDropout layer ensures that the mean and variance of the activations remain the same as before.
+"""
+mutable struct AlphaDropout{F}
+  p::F
+  active::Bool
+end
+
+function AlphaDropout(p)
+  @assert 0 ≤ p ≤ 1
+  AlphaDropout(p, true)
+end
+
+function (a::AlphaDropout)(x)
+  a.active || return x
+  λ = eltype(x)(1.0507009873554804934193349852946)
+  α = eltype(x)(1.6732632423543772848170429916717)
+  α1 = eltype(x)(-λ*α)
+  noise = randn(eltype(x), size(x))
+  x = @. x*(noise > (1 - a.p)) + α1 * (noise <= (1 - a.p))
+  A = (a.p + a.p * (1 - a.p) * α1 ^ 2)^0.5
+  B = -A * α1 * (1 - a.p)
+  x = @. A * x + B
+  return x
+end
+
+_testmode!(a::AlphaDropout, test) = (a.active = !test)
 
 """
     LayerNorm(h::Integer)
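
The affine step (the `A` and `B` above) corresponds to the variance-preserving rescaling derived in the Self-Normalizing Neural Networks paper linked in the docstring. In the paper's notation, with α′ = −λα and keep rate q (the probability of retaining a unit), the parameters that restore zero mean and unit variance after masking are:

```latex
a(q) = \left(q + {\alpha'}^{2}\, q\, (1 - q)\right)^{-1/2},
\qquad
b(q) = -a(q)\,(1 - q)\,\alpha'
```

so that the output is a(q)(x d + α′(1 − d)) + b(q) for a binary keep mask d.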
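A minimal usage sketch for the new layer, assuming the `selu` activation re-exported from NNlib and the `Flux.testmode!` toggle shown in the docs hunk above; the layer sizes here are arbitrary:

```julia
using Flux

# AlphaDropout is designed to follow selu activations so the
# self-normalizing property of the network is preserved.
m = Chain(Dense(784, 128, selu),
          AlphaDropout(0.2),
          Dense(128, 10))

x = randn(Float32, 784)
m(x)                # dropout active by default (training mode)

Flux.testmode!(m)   # disable dropout at evaluation time, as with Dropout
m(x)
```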