diff --git a/src/Flux.jl b/src/Flux.jl
index ef43edeb..a4f8cd93 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -5,15 +5,13 @@ module Flux
 using Base: tail
 using MacroTools, Juno, Requires, Reexport, Statistics, Random
 using MacroTools: @forward
+@reexport using NNlib
+using Zygote: Params, @adjoint, gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, ConvTranspose, MaxPool, MeanPool,
        DepthwiseConv, Dropout, AlphaDropout, LayerNorm, BatchNorm, InstanceNorm, GroupNorm,
        params, mapleaves, cpu, gpu, f32, f64
 
-@reexport using NNlib
-
-using Zygote
-
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 4ee6b758..9528cec4 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -1,16 +1,6 @@
-"""
-    testmode!(m)
-    testmode!(m, false)
+istraining() = false
 
-Put layers like [`Dropout`](@ref) and [`BatchNorm`](@ref) into testing mode
-(or back to training mode with `false`).
-"""
-function testmode!(m, val::Bool=true)
-  prefor(x -> _testmode!(x, val), m)
-  return m
-end
-
-_testmode!(m, test) = nothing
+@adjoint istraining() = true, _ -> nothing
 
 """
     Dropout(p)
@@ -23,44 +13,38 @@ Does nothing to the input once in [`testmode!`](@ref).
 """
 mutable struct Dropout{F}
   p::F
-  active::Bool
-end
-
-function Dropout(p)
-  @assert 0 ≤ p ≤ 1
-  Dropout{typeof(p)}(p, true)
+  function Dropout(p)
+    @assert 0 ≤ p ≤ 1
+    new{typeof(p)}(p)
+  end
 end
 
 _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
 
 function (a::Dropout)(x)
-  a.active || return x
+  istraining() || return x
   y = similar(x)
   rand!(y)
   y .= _dropout_kernel.(y, a.p, 1 - a.p)
   return x .* y
 end
 
-_testmode!(a::Dropout, test) = (a.active = !test)
-
 """
     AlphaDropout(p)
-A dropout layer. It is used in Self-Normalizing Neural Networks. 
+A dropout layer. It is used in Self-Normalizing Neural Networks.
 (https://papers.nips.cc/paper/6698-self-normalizing-neural-networks.pdf)
 The AlphaDropout layer ensures that mean and variance of activations remains the same as before.
 """
 mutable struct AlphaDropout{F}
   p::F
-  active::Bool
-end
-
-function AlphaDropout(p)
-  @assert 0 ≤ p ≤ 1
-  AlphaDropout(p,true)
+  function AlphaDropout(p)
+    @assert 0 ≤ p ≤ 1
+    new{typeof(p)}(p)
+  end
 end
 
 function (a::AlphaDropout)(x)
-  a.active || return x
+  istraining() || return x
   λ = eltype(x)(1.0507009873554804934193349852946)
   α = eltype(x)(1.6732632423543772848170429916717)
   α1 = eltype(x)(-λ*α)
@@ -72,8 +56,6 @@ function (a::AlphaDropout)(x)
   return x
 end
 
-_testmode!(a::AlphaDropout, test) = (a.active = !test)
-
 """
     LayerNorm(h::Integer)
 
@@ -133,13 +115,12 @@ mutable struct BatchNorm{F,V,W,N}
   σ²::W # moving std
   ϵ::N
   momentum::N
-  active::Bool
 end
 
 BatchNorm(chs::Integer, λ = identity;
           initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   BatchNorm(λ, initβ(chs), initγ(chs),
-            zeros(chs), ones(chs), ϵ, momentum, true)
+            zeros(chs), ones(chs), ϵ, momentum)
 
 function (BN::BatchNorm)(x)
   size(x, ndims(x)-1) == length(BN.β) ||
@@ -151,7 +132,7 @@ function (BN::BatchNorm)(x)
   m = prod(size(x)[1:end-2]) * size(x)[end]
   γ = reshape(BN.γ, affine_shape...)
   β = reshape(BN.β, affine_shape...)
-  if !BN.active
+  if !istraining()
     μ = reshape(BN.μ, affine_shape...)
     σ² = reshape(BN.σ², affine_shape...)
     ϵ = BN.ϵ
@@ -174,12 +155,10 @@ function (BN::BatchNorm)(x)
 end
 
 children(BN::BatchNorm) =
-  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum, BN.active)
+  (BN.λ, BN.β, BN.γ, BN.μ, BN.σ², BN.ϵ, BN.momentum)
 
 mapchildren(f, BN::BatchNorm) =  # e.g. mapchildren(cu, BN)
-  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum, BN.active)
-
-_testmode!(BN::BatchNorm, test) = (BN.active = !test)
+  BatchNorm(BN.λ, f(BN.β), f(BN.γ), f(BN.μ), f(BN.σ²), BN.ϵ, BN.momentum)
 
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(join(size(l.β), ", "))")
@@ -226,13 +205,12 @@ mutable struct InstanceNorm{F,V,W,N}
   σ²::W # moving std
   ϵ::N
   momentum::N
-  active::Bool
 end
 
 InstanceNorm(chs::Integer, λ = identity;
              initβ = (i) -> zeros(Float32, i), initγ = (i) -> ones(Float32, i), ϵ = 1f-5, momentum = 0.1f0) =
   InstanceNorm(λ, initβ(chs), initγ(chs),
-               zeros(chs), ones(chs), ϵ, momentum, true)
+               zeros(chs), ones(chs), ϵ, momentum)
 
 function (in::InstanceNorm)(x)
   size(x, ndims(x)-1) == length(in.β) ||
@@ -249,7 +227,7 @@ function (in::InstanceNorm)(x)
   m = prod(size(x)[1:end-2])
   γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
 
-  if !in.active
+  if !istraining()
     μ = expand_inst(in.μ, affine_shape)
     σ² = expand_inst(in.σ², affine_shape)
     ϵ = in.ϵ
@@ -274,12 +252,10 @@ function (in::InstanceNorm)(x)
 end
 
 children(in::InstanceNorm) =
-  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
+  (in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum)
 
 mapchildren(f, in::InstanceNorm) =  # e.g. mapchildren(cu, in)
-  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
-
-_testmode!(in::InstanceNorm, test) = (in.active = !test)
+  InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum)
 
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")
diff --git a/src/treelike.jl b/src/treelike.jl
index 6500c644..6392bbbb 100644
--- a/src/treelike.jl
+++ b/src/treelike.jl
@@ -1,5 +1,5 @@
 import Adapt: adapt, adapt_storage
-import .Zygote: IdSet
+import Zygote: IdSet
 
 children(x) = ()
 mapchildren(f, x) = x
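Note on the train/test switching this diff introduces: the mutable `active` flags and `testmode!` are replaced by `istraining()`, whose primal definition returns `false` while its Zygote `@adjoint` makes the same call return `true` whenever the surrounding code is being differentiated. The sketch below is a minimal standalone illustration of that mechanism, assuming only Zygote; the branching function `loss` is a hypothetical stand-in for a layer, not code from this PR.

```julia
using Zygote: @adjoint, gradient

# Plain calls hit the primal definition and report "not training".
istraining() = false

# Under differentiation, Zygote uses the adjoint's forward value instead,
# so the same call reports "training"; the pullback contributes no gradient.
@adjoint istraining() = true, _ -> nothing

# Hypothetical function that branches on the flag, mimicking how
# Dropout/BatchNorm/InstanceNorm now choose between modes.
loss(x) = istraining() ? 2x : x

loss(3.0)            # 3.0  -- evaluated normally, istraining() == false
gradient(loss, 3.0)  # (2.0,) -- inside gradient, istraining() == true, so the 2x branch runs
```

This is why the layers no longer need an `active` field: calling a model directly gives test-mode behaviour, while taking a `gradient` over it (as `Flux.train!` does) automatically puts Dropout, AlphaDropout, BatchNorm and InstanceNorm into training mode.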