change `dims` as unbroadcasted dims and keyword argument

This commit is contained in:
chengchingwen 2019-04-05 01:19:20 +08:00
parent 59da68b4d9
commit 261235311c
2 changed files with 8 additions and 7 deletions

View File

@ -31,13 +31,14 @@ function Dropout(p)
Dropout{typeof(p)}(p, true)
end
_dropout_shape(s, dims...) = tuple((i dims ? 1 : si for (i, si) enumerate(s))...)
_dropout_shape(s, ::Colon) = size(s)
_dropout_shape(s, dims) = tuple((i dims ? 1 : si for (i, si) enumerate(size(s)))...)
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x, dims=0)
function (a::Dropout)(x; dims = :)
a.active || return x
y = similar(x, _dropout_shape(size(x), dims...))
y = similar(x, _dropout_shape(x, dims))
rand!(y)
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y

View File

@ -29,12 +29,12 @@ using Flux.Tracker: data
x = rand(100, 50)
m = Dropout(0.5)
y = m(x, 2)
c = map(i->count(a->a==0, @view y[:, i]), 1:50)
@test minimum(c) == maximum(c)
y = m(x, 1)
y = m(x; dims=2)
c = map(i->count(a->a==0, @view y[i, :]), 1:100)
@test minimum(c) == maximum(c)
y = m(x; dims=1)
c = map(i->count(a->a==0, @view y[:, i]), 1:50)
@test minimum(c) == maximum(c)
end
@testset "BatchNorm" begin