change API to dims

This commit is contained in:
chengchingwen 2019-03-14 21:51:28 +08:00
parent 06003b72c7
commit 934f0840b2

View File

@ -31,11 +31,13 @@ function Dropout(p)
Dropout{typeof(p)}(p, true)
end
_dropout_shape(s, dims...) = tuple((i dims ? 1 : si for (i, si) enumerate(s))...)
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x, noise_shape=size(x))
function (a::Dropout)(x, dims=0)
a.active || return x
y = similar(x, noise_shape)
y = similar(x, _dropout_shape(size(x), dims...))
rand!(y)
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y