change API to dims
This commit is contained in:
parent
06003b72c7
commit
934f0840b2
@ -31,11 +31,13 @@ function Dropout(p)
|
|||||||
Dropout{typeof(p)}(p, true)
|
Dropout{typeof(p)}(p, true)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
_dropout_shape(s, dims...) = tuple((i ∈ dims ? 1 : si for (i, si) ∈ enumerate(s))...)
|
||||||
|
|
||||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||||
|
|
||||||
function (a::Dropout)(x, noise_shape=size(x))
|
function (a::Dropout)(x, dims=0)
|
||||||
a.active || return x
|
a.active || return x
|
||||||
y = similar(x, noise_shape)
|
y = similar(x, _dropout_shape(size(x), dims...))
|
||||||
rand!(y)
|
rand!(y)
|
||||||
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
||||||
return x .* y
|
return x .* y
|
||||||
|
Loading…
Reference in New Issue
Block a user