Mike J Innes 2019-03-08 12:56:19 +00:00 committed by Elliot Saba
parent c313be8e95
commit 82ee61f5be
3 changed files with 24 additions and 50 deletions

src/Flux.jl

@@ -5,15 +5,13 @@ module Flux
 using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
+@reexport using NNlib
+using Zygote: Params, @adjoint, gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        params, mapleaves, cpu, gpu, f32, f64
-@reexport using NNlib
-using Zygote
 
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs

src/layers/normalise.jl

@@ -1,16 +1,6 @@
-"""
-    testmode!(m)
-    testmode!(m, false)
-
-Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
-(or back to training mode with `false`).
-"""
-function testmode!(m, val::Bool=true)
-  prefor(x -> _testmode!(x, val), m)
-  return m
-end
-
-_testmode!(m, test) = nothing
+istraining() = false
+
+@adjoint istraining() = true, _ -> nothing
 
 """
     Dropout(p)
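
The whole commit rests on these two lines. `istraining()` is an ordinary function that returns `false`, but the `@adjoint` rule makes Zygote's tracing pass see `true` (with a pullback that contributes no gradient), so a layer is in training mode exactly when it is being differentiated, and no mutable flag is needed. A minimal standalone sketch of the trick, using an illustrative function `f`:

    using Zygote: @adjoint, gradient

    istraining() = false
    @adjoint istraining() = true, _ -> nothing  # forward value `true`, no gradient

    f(x) = istraining() ? 2x : x

    f(3)            # 3: a plain call sees istraining() == false
    gradient(f, 3)  # (2,): inside the pullback istraining() == true, so f is x -> 2x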
@@ -23,26 +13,22 @@ Does nothing to the input once in [`testmode!`](@ref).
 """
 mutable struct Dropout{F}
   p::F
-  active::Bool
-end
-
-function Dropout(p)
-  @assert 0 ≤ p ≤ 1
-  Dropout{typeof(p)}(p, true)
+  function Dropout(p)
+    @assert 0 ≤ p ≤ 1
+    new{typeof(p)}(p)
+  end
 end
 
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
 
 function (a::Dropout)(x)
-  a.active || return x
+  istraining() || return x
   y = similar(x)
   rand!(y)
   y .= _dropout_kernel.(y, a.p, 1 - a.p)
   return x .* y
 end
 
-_testmode!(a::Dropout, test) = (a.active = !test)
-
 """
     AlphaDropout(p)
 A dropout layer. It is used in Self-Normalizing Neural Networks.
@@ -51,16 +37,14 @@ The AlphaDropout layer ensures that mean and variance of activations remains the
 """
 mutable struct AlphaDropout{F}
   p::F
-  active::Bool
-end
-
-function AlphaDropout(p)
-  @assert 0 ≤ p ≤ 1
-  AlphaDropout(p,true)
+  function AlphaDropout(p)
+    @assert 0 ≤ p ≤ 1
+    new{typeof(p)}(p)
+  end
 end
 
 function (a::AlphaDropout)(x)
-  a.active || return x
+  istraining() || return x
   λ = eltype(x)(1.0507009873554804934193349852946)
   α = eltype(x)(1.6732632423543772848170429916717)
   α1 = eltype(x)(-λ*α)
@@ -72,8 +56,6 @@ function (a::AlphaDropout)(x)
   return x
 end
 
-_testmode!(a::AlphaDropout, test) = (a.active = !test)
-
 """
     LayerNorm(h::Integer)
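
With `active` gone, both dropout layers fall back to the identity whenever they are not being differentiated. A quick check of the inference path (illustrative values, assuming this branch of Flux is loaded):

    using Flux

    d = Dropout(0.5)
    x = ones(Float32, 4)
    d(x) == x  # true: outside a gradient call istraining() is false, so no mask is sampled

Under `gradient`, `istraining()` is true and the masking branch runs, sampling a fresh mask on each call.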
@@ -133,13 +115,12 @@ mutable struct BatchNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
-  active::Bool
 end
 
 BatchNorm(chs::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   BatchNorm(λ, initβ(chs), initγ(chs),
-            zeros(chs), ones(chs), ϵ, momentum, true)
+            zeros(chs), ones(chs), ϵ, momentum)
 
 function (BN::BatchNorm)(x)
   size(x, ndims(x)-1) == length(BN.β) ||
@@ -151,7 +132,7 @@ function (BN::BatchNorm)(x)
   m = prod(size(x)[1:end-2]) * size(x)[end]
   γ = reshape(BN.γ, affine_shape...)
   β = reshape(BN.β, affine_shape...)
-  if !BN.active
+  if !istraining()
     μ = reshape(BN.μ, affine_shape...)
     σ² = reshape(BN.σ², affine_shape...)
     ϵ = BN.ϵ
@@ -174,12 +155,10 @@ end
 
 children(BN::BatchNorm) =
-  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
+  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
 
 mapchildren(f, BN::BatchNorm) = # e.g. mapchildren(cu, BN)
-  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
+  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
 
-_testmode!(BN::BatchNorm, test) = (BN.active = !test)
-
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(join(size(l.β), ", "))")
@@ -226,13 +205,12 @@ mutable struct InstanceNorm{F,V,W,N}
   σ²::W  # moving std
   ϵ::N
   momentum::N
-  active::Bool
 end
 
 InstanceNorm(chs::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   InstanceNorm(λ, initβ(chs), initγ(chs),
-               zeros(chs), ones(chs), ϵ, momentum, true)
+               zeros(chs), ones(chs), ϵ, momentum)
 
 function (in::InstanceNorm)(x)
   size(x, ndims(x)-1) == length(in.β) ||
@@ -249,7 +227,7 @@ function (in::InstanceNorm)(x)
   m = prod(size(x)[1:end-2])
   γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
-  if !in.active
+  if !istraining()
     μ = expand_inst(in.μ, affine_shape)
     σ² = expand_inst(in.σ², affine_shape)
     ϵ = in.ϵ
@@ -274,12 +252,10 @@ end
 
 children(in::InstanceNorm) =
-  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
+  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
 
 mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
-  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
+  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
 
-_testmode!(in::InstanceNorm, test) = (in.active = !test)
-
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")

src/treelike.jl

@@ -1,5 +1,5 @@
 import Adapt: adapt, adapt_storage
-import .Zygote: IdSet
+import Zygote: IdSet
 
 children(x) = ()
 mapchildren(f, x) = x
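
For context on this one-line fix (Zygote is now a real dependency rather than a locally included module): `children` and `mapchildren` are the tree-walking primitives this file builds on, and the two fallbacks above make any unrecognised value a leaf. For example, assuming the definitions above are in scope:

    children(3.0)            # (): an arbitrary value has no children
    mapchildren(sqrt, 3.0)   # 3.0: mapping over a leaf returns it unchanged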