add dropout

This commit is contained in:
CarloLucibello 2017-10-23 10:12:53 +02:00
parent 2a66545ef8
commit 2e1ed4c3fc
4 changed files with 67 additions and 3 deletions

View File

@ -7,9 +7,9 @@ module Flux
using Juno, Requires
using Lazy: @forward
export Chain, Dense, RNN, LSTM,
export Chain, Dense, RNN, LSTM, Dropout,
SGD, ADAM, Momentum, Nesterov,
param, params, mapleaves
param, params, mapleaves, setmode!
using NNlib
export σ, relu, leakyrelu, elu, swish, softmax

View File

@ -27,7 +27,7 @@ end
children(c::Chain) = c.layers
mapchildren(f, c::Chain) = Chain(f.(c.layers)...)
(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers)
(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers)
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
@ -78,3 +78,43 @@ function Base.show(io::IO, l::Dense)
l.σ == identity || print(io, ", ", l.σ)
print(io, ")")
end
"""
Dropout(p; mode=:train)
A Dropout layer. In `:train` mode sets input components `x[i]` to zero with
probability `p` and to `x[i]/(1-p)` with probability `(1-p)`.
In `:eval` mode it doesn't alter the input: `x == Dropout(p; mode=:eval)(x)`.
Change the mode with [`setmode!`](@ref).
"""
mutable struct Dropout{F}
p::F
mode::Symbol
end
Dropout(p::F; mode=:train) where {F} = Dropout{F}(p, mode)
function (a::Dropout)(x)
if a.mode == :eval
return x
else
if 0 < a.p < 1
y = similar(x)
rand!(y)
q = 1 - a.p
@inbounds for i=1:length(y)
y[i] = y[i] > a.p ? 1 / q : 0
end
return y .* x
elseif a.p == 0
return x
elseif a.p == 1
return zeros(x)
end
end
end
setmode!(a, mode::Symbol) = nothing
setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c)
setmode!(a::Dropout, mode::Symbol) = a.mode = mode

23
test/layers.jl Normal file
View File

@ -0,0 +1,23 @@
@testset "dropout" begin
x = [1.,2.,3.]
@test x === Dropout(0.1, mode=:eval)(x)
@test x === Dropout(0, mode=:train)(x)
@test all(zeros(x) .== Dropout(1, mode=:train)(x))
x = rand(100)
m = Dropout(0.9)
y = m(x)
@test count(a->a==0, y) > 50
setmode!(m, :eval)
y = m(x)
@test count(a->a==0, y) == 0
x = rand(100)
m = Chain(Dense(100,100),
Dropout(0.9))
y = m(x)
@test count(a->a.data[] == 0, y) > 50
setmode!(m, :eval)
y = m(x)
@test count(a->a.data[] == 0, y) == 0
end

View File

@ -4,5 +4,6 @@ using Flux, Base.Test
include("utils.jl")
include("tracker.jl")
include("layers.jl")
end