Made sure Gradients are not lost.
This commit is contained in:
parent
922e9c9bc2
commit
29b853e0bb
@ -63,13 +63,11 @@ function (a::AlphaDropout)(x)
|
|||||||
a.active || return x
|
a.active || return x
|
||||||
α = -1.75813631
|
α = -1.75813631
|
||||||
noise = randn(Float64, size(x.data))
|
noise = randn(Float64, size(x.data))
|
||||||
y = collect(x)
|
x.data .= x.data .* (noise .> (1 - a.p)) + α .* (noise .<= (1 - a.p))
|
||||||
y .= y .* (noise .> (1 - a.p)) + α .* (noise .<= (1 - a.p))
|
|
||||||
A = (a.p + a.p * (1 - a.p) * α ^ 2)^0.5
|
A = (a.p + a.p * (1 - a.p) * α ^ 2)^0.5
|
||||||
B = -A * α * (1 - a.p)
|
B = -A * α * (1 - a.p)
|
||||||
y .= A .* y .+ B
|
x.data .= A .* x.data .+ B
|
||||||
x1 = param(y)
|
return x
|
||||||
return x1
|
|
||||||
end
|
end
|
||||||
|
|
||||||
_testmode!(a::AlphaDropout, test) = (a.active = !test)
|
_testmode!(a::AlphaDropout, test) = (a.active = !test)
|
||||||
|
Loading…
Reference in New Issue
Block a user