commit 4c1b1eb18c
@@ -30,3 +30,11 @@ leakyrelu
 elu
 swish
 ```
+
+## Normalisation & Regularisation
+
+These layers don't affect the structure of the network but may improve training times or reduce overfitting.
+
+```@docs
+Dropout
+```
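As a sketch of the documented usage (the layer sizes below are invented for illustration, not taken from the commit):

```julia
using Flux

# Dropout slots between layers; during training it randomly zeroes
# activations, and testmode! disables it for evaluation.
m = Chain(Dense(10, 5, relu),
          Dropout(0.5),
          Dense(5, 2))
```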
@@ -7,7 +7,7 @@ module Flux
 using Juno, Requires
 using Lazy: @forward
 
-export Chain, Dense, RNN, LSTM,
+export Chain, Dense, RNN, LSTM, Dropout,
   SGD, ADAM, Momentum, Nesterov,
   param, params, mapleaves
 
@@ -27,5 +27,6 @@ include("tree.jl")
 include("layers/stateless.jl")
 include("layers/basic.jl")
 include("layers/recurrent.jl")
+include("layers/normalisation.jl")
 
 end # module
@@ -27,7 +27,7 @@ end
 children(c::Chain) = c.layers
 mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
 
-(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
+(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
 
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
 
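The `foldl` definition (with `s` renamed to `c` for consistency) is just left-to-right application of the layers; roughly, assuming plain functions as layers:

```julia
# Chain(f, g)(x) is equivalent to g(f(x)): the input is folded
# through the layers in order.
double = x -> 2 .* x
inc    = x -> x .+ 1
c = Chain(double, inc)
c([1, 2]) == [3, 5]  # inc(double([1, 2]))
```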
src/layers/normalisation.jl (new file)
@@ -0,0 +1,45 @@
+"""
+    testmode!(m)
+    testmode!(m, false)
+
+Put layers like [`Dropout`](@ref) and `BatchNorm` into testing mode (or back to
+training mode with `false`).
+"""
+function testmode!(m, val::Bool=true)
+  prefor(x -> _testmode!(x, val), m)
+  return m
+end
+
+_testmode!(m, test) = nothing
+
+"""
+    Dropout(p)
+
+A Dropout layer. For each input, either sets that input to `0` (with probability
+`p`) or scales it by `1/(1-p)`. This is used as a regularisation, i.e. it
+reduces overfitting during training.
+
+Does nothing to the input once in [`testmode!`](@ref).
+"""
+mutable struct Dropout{F}
+  p::F
+  active::Bool
+end
+
+function Dropout(p)
+  @assert 0 ≤ p ≤ 1
+  Dropout{typeof(p)}(p, true)
+end
+
+function (a::Dropout)(x)
+  a.active || return x
+  y = similar(x)
+  rand!(y)
+  q = 1 - a.p
+  @inbounds for i = 1:length(y)
+    y[i] = y[i] > a.p ? 1 / q : 0
+  end
+  return y .* x
+end
+
+_testmode!(a::Dropout, test) = (a.active = !test)
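The `1/(1-p)` factor is the usual "inverted dropout" scaling: each input survives with probability `1 - p` and the survivors are scaled up accordingly, so the expected activation is unchanged (`E[mask] = (1-p) * 1/(1-p) = 1`). A quick sanity check of that invariant, as a sketch rather than part of the commit:

```julia
x = ones(10_000)
d = Dropout(0.3)

# About 30% of entries are zeroed; the survivors are scaled by 1/0.7,
# so the mean output stays close to the mean input.
mean(d(x))  # ≈ 1.0
```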
@@ -56,6 +56,18 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
 
 Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
 
+value(x) = x
+value(x::TrackedArray) = data(x)
+value(x::TrackedScalar) = data(x)[]
+
+Base.:(==)(x::TrackedArray, y) = value(x) == y
+Base.:(==)(y, x::TrackedArray) = y == value(x)
+Base.:(==)(x::TrackedArray, y::TrackedArray) = value(x) == value(y)
+
+Base.isless(x::TrackedScalar, y) = isless(value(x), y)
+Base.isless(x, y::TrackedScalar) = isless(x, value(y))
+Base.isless(x::TrackedScalar, y::TrackedScalar) = isless(value(x), value(y))
+
 Base.show(io::IO, ::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} =
   print(io, "TrackedArray{…,$A}")
 
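`value` unwraps tracked containers so that equality and ordering work transparently against plain arrays and numbers, which the new Dropout tests depend on. An illustrative sketch (the values are made up):

```julia
using Flux

w = param([1.0, 2.0])   # a TrackedArray wrapping plain data
w == [1.0, 2.0]          # true: compared via value(w)
[1.0, 2.0] == w          # true: the symmetric method
```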
28
test/layers/normalisation.jl
Normal file
28
test/layers/normalisation.jl
Normal file
@ -0,0 +1,28 @@
|
|||||||
|
using Flux: testmode!
|
||||||
|
|
||||||
|
@testset "Dropout" begin
|
||||||
|
x = [1.,2.,3.]
|
||||||
|
@test x == testmode!(Dropout(0.1))(x)
|
||||||
|
@test x == Dropout(0)(x)
|
||||||
|
@test zeros(x) == Dropout(1)(x)
|
||||||
|
|
||||||
|
x = rand(100)
|
||||||
|
m = Dropout(0.9)
|
||||||
|
y = m(x)
|
||||||
|
@test count(a->a==0, y) > 50
|
||||||
|
testmode!(m)
|
||||||
|
y = m(x)
|
||||||
|
@test count(a->a==0, y) == 0
|
||||||
|
testmode!(m, false)
|
||||||
|
y = m(x)
|
||||||
|
@test count(a->a==0, y) > 50
|
||||||
|
|
||||||
|
x = rand(100)
|
||||||
|
m = Chain(Dense(100,100),
|
||||||
|
Dropout(0.9))
|
||||||
|
y = m(x)
|
||||||
|
@test count(a->a == 0, y) > 50
|
||||||
|
testmode!(m)
|
||||||
|
y = m(x)
|
||||||
|
@test count(a->a == 0, y) == 0
|
||||||
|
end
|
@@ -4,5 +4,6 @@ using Flux, Base.Test
 
 include("utils.jl")
 include("tracker.jl")
+include("layers/normalisation.jl")
 
 end
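With the include wired into runtests.jl, the new tests run under the standard entry point (assuming an ordinary package checkout):

```julia
Pkg.test("Flux")
```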