noise shape for dropout
This commit is contained in:
parent
4be08fe194
commit
06003b72c7
@ -33,9 +33,9 @@ end
|
|||||||
|
|
||||||
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
|
||||||
|
|
||||||
function (a::Dropout)(x)
|
function (a::Dropout)(x, noise_shape=size(x))
|
||||||
a.active || return x
|
a.active || return x
|
||||||
y = similar(x)
|
y = similar(x, noise_shape)
|
||||||
rand!(y)
|
rand!(y)
|
||||||
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
y .= _dropout_kernel.(y, a.p, 1 - a.p)
|
||||||
return x .* y
|
return x .* y
|
||||||
|
@ -26,6 +26,15 @@ using Flux.Tracker: data
|
|||||||
testmode!(m)
|
testmode!(m)
|
||||||
y = m(x)
|
y = m(x)
|
||||||
@test count(a->a == 0, y) == 0
|
@test count(a->a == 0, y) == 0
|
||||||
|
|
||||||
|
x = rand(100, 50)
|
||||||
|
m = Dropout(0.5)
|
||||||
|
y = m(x, (100, 1))
|
||||||
|
c = map(i->count(a->a==0, @view y[:, i]), 1:50)
|
||||||
|
@test minimum(c) == maximum(c)
|
||||||
|
y = m(x, (1, 50))
|
||||||
|
c = map(i->count(a->a==0, @view y[i, :]), 1:100)
|
||||||
|
@test minimum(c) == maximum(c)
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "BatchNorm" begin
|
@testset "BatchNorm" begin
|
||||||
|
Loading…
Reference in New Issue
Block a user