fix dropout
This commit is contained in:
parent
094b38ac03
commit
1fc584102d
@ -2,6 +2,19 @@ istraining() = false
|
|||||||
|
|
||||||
@adjoint istraining() = true, _ -> nothing
|
@adjoint istraining() = true, _ -> nothing
|
||||||
|
|
||||||
|
_dropout_shape(s, ::Colon) = size(s)
|
||||||
|
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
||||||
|
|
||||||
|
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||||
|
|
||||||
|
dropout(x, p; dims = :) = x
|
||||||
|
|
||||||
|
@adjoint function dropout(x, p; dims = :)
|
||||||
|
y = rand!(similar(x, _dropout_shape(x, dims)))
|
||||||
|
y .= _dropout_kernel.(y, p, 1 - p)
|
||||||
|
return x .* y, Δ -> (Δ .* y, nothing)
|
||||||
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Dropout(p, dims = :)
|
Dropout(p, dims = :)
|
||||||
|
|
||||||
@ -12,33 +25,17 @@ A Dropout layer. For each input, either sets that input to `0` (with probability
|
|||||||
|
|
||||||
Does nothing to the input once in [`testmode!`](@ref).
|
Does nothing to the input once in [`testmode!`](@ref).
|
||||||
"""
|
"""
|
||||||
mutable struct Dropout{F}
|
mutable struct Dropout{F,D}
|
||||||
p::F
|
p::F
|
||||||
dims::Union{Colon, Int, NTuple{N, Int} where N}
|
dims::D
|
||||||
end
|
end
|
||||||
|
|
||||||
function Dropout(p; dims = :)
|
function Dropout(p; dims = :)
|
||||||
@assert 0 ≤ p ≤ 1
|
@assert 0 ≤ p ≤ 1
|
||||||
Dropout{typeof(p)}(p, dims)
|
Dropout{typeof(p),typeof(dims)}(p, dims)
|
||||||
end
|
end
|
||||||
|
|
||||||
_dropout_shape(s, ::Colon) = size(s)
|
(a::Dropout)(x) = dropout(x, a.p; dims = a.dims)
|
||||||
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
|
||||||
|
|
||||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
|
||||||
|
|
||||||
function dropout(x, p; dims = :)
|
|
||||||
istraining() || return x
|
|
||||||
y = similar(x, _dropout_shape(x, dims))
|
|
||||||
rand!(y)
|
|
||||||
y .= _dropout_kernel.(y, p, 1 - p)
|
|
||||||
return x .* y
|
|
||||||
end
|
|
||||||
|
|
||||||
function (a::Dropout)(x)
|
|
||||||
istraining() || return x
|
|
||||||
return dropout(x, a.p; dims = a.dims)
|
|
||||||
end
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
AlphaDropout(p)
|
AlphaDropout(p)
|
||||||
|
Loading…
Reference in New Issue
Block a user