diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 082e651e..95599867 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -14,9 +14,10 @@ Does nothing to the input once in [`testmode!`](@ref). """ mutable struct Dropout{F} p::F - function Dropout(p) + dims::Union{Colon, Int, NTuple{N, Int} where N} + function Dropout(p; dims = :) @assert 0 ≤ p ≤ 1 - new{typeof(p)}(p) + Dropout{typeof(p)}(p, dims) end end