diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 74905a36..54f5eb56 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -31,15 +31,14 @@ function Dropout(p) Dropout{typeof(p)}(p, true) end +_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) + function (a::Dropout)(x) a.active || return x y = similar(x) rand!(y) - q = 1 - a.p - @inbounds for i=1:length(y) - y[i] = y[i] > a.p ? 1 / q : 0 - end - return y .* x + y .= _dropout_kernel.(y, a.p, 1 - a.p) + return x .* y end _testmode!(a::Dropout, test) = (a.active = !test)