From 934f0840b2bd20dd43bb563d62ebd3199a87b592 Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Thu, 14 Mar 2019 21:51:28 +0800 Subject: [PATCH] change API to dims --- src/layers/normalise.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 9617b2c1..4af6a196 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -31,11 +31,13 @@ function Dropout(p) Dropout{typeof(p)}(p, true) end +_dropout_shape(s, dims...) = tuple((i ∈ dims ? 1 : si for (i, si) ∈ enumerate(s))...) + _dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0) -function (a::Dropout)(x, noise_shape=size(x)) +function (a::Dropout)(x, dims=0) a.active || return x - y = similar(x, noise_shape) + y = similar(x, _dropout_shape(size(x), dims...)) rand!(y) y .= _dropout_kernel.(y, a.p, 1 - a.p) return x .* y