noise shape for dropout

This commit is contained in:
chengchingwen 2019-01-22 23:51:38 +08:00
parent 4be08fe194
commit 06003b72c7
2 changed files with 11 additions and 2 deletions

View File

@ -33,9 +33,9 @@ end
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x)
function (a::Dropout)(x, noise_shape=size(x))
a.active || return x
y = similar(x)
y = similar(x, noise_shape)
rand!(y)
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y

View File

@ -26,6 +26,15 @@ using Flux.Tracker: data
testmode!(m)
y = m(x)
@test count(a->a == 0, y) == 0
x = rand(100, 50)
m = Dropout(0.5)
y = m(x, (100, 1))
c = map(i->count(a->a==0, @view y[:, i]), 1:50)
@test minimum(c) == maximum(c)
y = m(x, (1, 50))
c = map(i->count(a->a==0, @view y[i, :]), 1:100)
@test minimum(c) == maximum(c)
end
@testset "BatchNorm" begin