Use broadcast for dropout

Should be fast enough on GPU now that it's not going to be an optimization target again for a while. Hopefully isn't meaningfully slower on CPU?
This commit is contained in:
James Bradbury 2018-05-20 04:04:33 -07:00 committed by GitHub
parent e92f840510
commit af12f006f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -31,15 +31,14 @@ function Dropout(p)
Dropout{typeof(p)}(p, true)
end
_dropout_kernel(y::T, p, q) where {T} = y > p ? T(1 / q) : T(0)
function (a::Dropout)(x)
a.active || return x
y = similar(x)
rand!(y)
q = 1 - a.p
@inbounds for i=1:length(y)
y[i] = y[i] > a.p ? 1 / q : 0
end
return y .* x
y .= _dropout_kernel.(y, a.p, 1 - a.p)
return x .* y
end
_testmode!(a::Dropout, test) = (a.active = !test)