reorganise
parent 711ea09d99
commit cf6b930f63
@@ -30,3 +30,11 @@ leakyrelu
 elu
 swish
 ```
+
+## Normalisation & Regularisation
+
+These layers don't affect the structure of the network but may improve training times or reduce overfitting.
+
+```@docs
+Dropout
+```

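For context (not part of the commit): a minimal sketch of how a `Dropout` layer slots into a model as regularisation. The layer sizes and input here are made up for illustration.

```julia
using Flux

# Hypothetical two-layer classifier with dropout between the layers.
m = Chain(
  Dense(28^2, 128, relu),
  Dropout(0.5),    # while training, zeroes roughly half the activations
  Dense(128, 10),
  softmax)

x = rand(28^2)  # dummy input
m(x)            # stochastic while the Dropout layer is active
```
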
@@ -9,7 +9,7 @@ using Lazy: @forward

 export Chain, Dense, RNN, LSTM, Dropout,
   SGD, ADAM, Momentum, Nesterov,
-  param, params, mapleaves, testmode!
+  param, params, mapleaves

 using NNlib
 export σ, relu, leakyrelu, elu, swish, softmax

@@ -27,5 +27,6 @@ include("tree.jl")
 include("layers/stateless.jl")
 include("layers/basic.jl")
 include("layers/recurrent.jl")
+include("layers/normalisation.jl")

 end # module

@@ -78,47 +78,3 @@ function Base.show(io::IO, l::Dense)
   l.σ == identity || print(io, ", ", l.σ)
   print(io, ")")
 end
-
-
-"""
-    Dropout(p; testmode=false)
-
-A Dropout layer. If `testmode=false` mode sets input components `x[i]` to zero with
-probability `p` and to `x[i]/(1-p)` with probability `(1-p)`.
-
-In `testmode=true` it doesn't alter the input: `x == Dropout(p; mode=:eval)(x)`.
-Change the mode with [`testmode!`](@ref).
-"""
-mutable struct Dropout{F}
-  p::F
-  testmode::Bool
-end
-Dropout(p::F; testmode::Bool=false) where {F} = Dropout{F}(p, testmode)
-
-function (a::Dropout)(x)
-  if a.testmode
-    return x
-  else
-    if 0 < a.p < 1
-      y = similar(x)
-      rand!(y)
-      q = 1 - a.p
-      @inbounds for i=1:length(y)
-        y[i] = y[i] > a.p ? 1 / q : 0
-      end
-      return y .* x
-    elseif a.p == 0
-      return x
-    elseif a.p == 1
-      return zeros(x)
-    end
-  end
-end
-
-"""
-    testmode!(m, val=true)
-
-Set model `m` in test mode if `val=true`, and in training mode otherwise.
-This has an affect only if `m` contains [`Dropout`](@ref) or `BatchNorm` layers.
-"""
-testmode!(m, val::Bool=true) = prefor(x -> :testmode ∈ fieldnames(x) && (x.testmode = val), m)

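Both the removed implementation and its replacement below use "inverted" dropout: surviving components are scaled by `1/(1-p)`, so each output's expected value equals its input. A one-line sanity check of that identity (illustrative only, not from the commit):

```julia
p = 0.5; x = 2.0
# E[y] = p * 0 + (1 - p) * x/(1 - p) = x
p * 0 + (1 - p) * (x / (1 - p)) == x  # true
```
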
@@ -0,0 +1,45 @@
+"""
+    testmode!(m)
+    testmode!(m, false)
+
+Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to
+training mode with `false`).
+"""
+function testmode!(m, val::Bool=true)
+  prefor(x -> _testmode!(x, val), m)
+  return m
+end
+
+_testmode!(m, test) = nothing
+
+"""
+    Dropout(p)
+
+A Dropout layer. For each input, either sets that input to `0` (with probability
+`p`) or scales it by `1/(1-p)`. This is used as a regularisation, i.e. it
+reduces overfitting during training.
+
+Does nothing to the input once in [`testmode!`](@ref).
+"""
+mutable struct Dropout{F}
+  p::F
+  active::Bool
+end
+
+function Dropout(p)
+  @assert 0 ≤ p ≤ 1
+  Dropout{typeof(p)}(p, true)
+end
+
+function (a::Dropout)(x)
+  a.active || return x
+  y = similar(x)
+  rand!(y)
+  q = 1 - a.p
+  @inbounds for i=1:length(y)
+    y[i] = y[i] > a.p ? 1 / q : 0
+  end
+  return y .* x
+end
+
+_testmode!(a::Dropout, test) = (a.active = !test)

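A short sketch of the new API in action (mine, not part of the commit): `testmode!` walks the model with `prefor` and calls `_testmode!` on every node, so only layers that overload `_testmode!`, like `Dropout`, react.

```julia
using Flux
using Flux: testmode!  # no longer exported, so import it explicitly

d = Dropout(0.5)
x = ones(10)

d(x)                 # stochastic: entries are either 0.0 or 2.0 (1/(1-p) = 2)
testmode!(d)         # flips d.active to false via _testmode!
d(x) == x            # true: dropout is a no-op in test mode
testmode!(d, false)  # back to training mode
```
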
@@ -1,8 +1,10 @@
-@testset "dropout" begin
+using Flux: testmode!
+
+@testset "Dropout" begin
   x = [1.,2.,3.]
-  @test x === Dropout(0.1, testmode=true)(x)
-  @test x === Dropout(0, testmode=false)(x)
-  @test all(zeros(x) .== Dropout(1, testmode=false)(x))
+  @test x == testmode!(Dropout(0.1))(x)
+  @test x == Dropout(0)(x)
+  @test zeros(x) == Dropout(1)(x)

   x = rand(100)
   m = Dropout(0.9)

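The hunk above is cut off after `m = Dropout(0.9)`, so the rest of the test isn't shown here. As an illustration of the kind of property a high-rate dropout layer can be checked for (my sketch, not the repository's test):

```julia
using Flux

m = Dropout(0.9)
y = m(rand(1000))

# With p = 0.9, roughly 90% of the outputs should be zero.
count(iszero, y) / length(y)  # ≈ 0.9 on average
```
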
@@ -4,6 +4,6 @@ using Flux, Base.Test

 include("utils.jl")
 include("tracker.jl")
-include("layers.jl")
+include("layers/normalisation.jl")

 end