noise shape for dropout
This commit is contained in:
parent
4be08fe194
commit
06003b72c7
|
@ -33,9 +33,9 @@ end
|
|||
|
||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||
|
||||
function (a::Dropout)(x)
|
||||
function (a::Dropout)(x, noise_shape=size(x))
|
||||
a.active || return x
|
||||
y = similar(x)
|
||||
y = similar(x, noise_shape)
|
||||
rand!(y)
|
||||
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
||||
return x .* y
|
||||
|
|
|
@ -26,6 +26,15 @@ using Flux.Tracker: data
|
|||
testmode!(m)
|
||||
y = m(x)
|
||||
@test count(a->a == 0, y) == 0
|
||||
|
||||
x = rand(100, 50)
|
||||
m = Dropout(0.5)
|
||||
y = m(x, (100, 1))
|
||||
c = map(i->count(a->a==0, @view y[:, i]), 1:50)
|
||||
@test minimum(c) == maximum(c)
|
||||
y = m(x, (1, 50))
|
||||
c = map(i->count(a->a==0, @view y[i, :]), 1:100)
|
||||
@test minimum(c) == maximum(c)
|
||||
end
|
||||
|
||||
@testset "BatchNorm" begin
|
||||
|
|
Loading…
Reference in New Issue