diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 1adc3050..3755f3fc 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -29,7 +29,7 @@ _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) function dropout(x, p; dims = :) istraining() || return x - y = similar(x) + y = similar(x, _dropout_shape(x, dims)) rand!(y) y .= _dropout_kernel.(y, p, 1 - p) return x .* y