make dims as field of Dropout

This commit is contained in:
chengchingwen 2019-05-10 23:45:50 +08:00
parent 261235311c
commit 5c5140683c
2 changed files with 27 additions and 11 deletions

View File

@ -13,22 +13,24 @@ end
_testmode!(m, test) = nothing _testmode!(m, test) = nothing
""" """
Dropout(p) Dropout(p, dims = :)
A Dropout layer. For each input, either sets that input to `0` (with probability A Dropout layer. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. This is used as a regularisation, i.e. it `p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
reduces overfitting during training. dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
used as a regularisation, i.e. it reduces overfitting during training. see also [`dropout`](@ref).
Does nothing to the input once in [`testmode!`](@ref). Does nothing to the input once in [`testmode!`](@ref).
""" """
mutable struct Dropout{F} mutable struct Dropout{F}
p::F p::F
dims::Union{Colon, Int, NTuple{N, Int} where N}
active::Bool active::Bool
end end
function Dropout(p) function Dropout(p; dims = :)
@assert 0 p 1 @assert 0 p 1
Dropout{typeof(p)}(p, true) Dropout{typeof(p)}(p, dims, true)
end end
_dropout_shape(s, ::Colon) = size(s) _dropout_shape(s, ::Colon) = size(s)
@ -36,14 +38,27 @@ _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(s
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x; dims = :)
a.active || return x """
dropout(x, p; dims = :)
The dropout function. For each input, either sets that input to `0` (with probability
`p`) or scales it by `1/(1-p)`. The `dims` argument is to specified the unbroadcasted
dimensions, i.e. `dims=1` does dropout along columns and `dims=2` along rows. This is
used as a regularisation, i.e. it reduces overfitting during training.
"""
function dropout(x, p; dims = :)
y = similar(x, _dropout_shape(x, dims)) y = similar(x, _dropout_shape(x, dims))
rand!(y) rand!(y)
y .= _dropout_kernel.(y, a.p, 1 - a.p) y .= _dropout_kernel.(y, p, 1 - p)
return x .* y return x .* y
end end
function (a::Dropout)(x)
a.active || return x
return dropout(x, a.p; dims = a.dims)
end
_testmode!(a::Dropout, test) = (a.active = !test) _testmode!(a::Dropout, test) = (a.active = !test)
""" """

View File

@ -28,11 +28,12 @@ using Flux.Tracker: data
@test count(a->a == 0, y) == 0 @test count(a->a == 0, y) == 0
x = rand(100, 50) x = rand(100, 50)
m = Dropout(0.5) m = Dropout(0.5, dims = 2)
y = m(x; dims=2) y = m(x)
c = map(i->count(a->a==0, @view y[i, :]), 1:100) c = map(i->count(a->a==0, @view y[i, :]), 1:100)
@test minimum(c) == maximum(c) @test minimum(c) == maximum(c)
y = m(x; dims=1) m = Dropout(0.5, dims = 1)
y = m(x)
c = map(i->count(a->a==0, @view y[:, i]), 1:50) c = map(i->count(a->a==0, @view y[:, i]), 1:50)
@test minimum(c) == maximum(c) @test minimum(c) == maximum(c)
end end