diff --git a/src/Flux.jl b/src/Flux.jl index e4f170f2..daeaa9ac 100644 --- a/src/Flux.jl +++ b/src/Flux.jl @@ -7,9 +7,9 @@ module Flux using Juno, Requires using Lazy: @forward -export Chain, Dense, RNN, LSTM, +export Chain, Dense, RNN, LSTM, Dropout, SGD, ADAM, Momentum, Nesterov, - param, params, mapleaves + param, params, mapleaves, setmode! using NNlib export σ, relu, leakyrelu, elu, swish, softmax diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 9c8b1016..088cf1e1 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -27,7 +27,7 @@ end children(c::Chain) = c.layers mapchildren(f, c::Chain) = Chain(f.(c.layers)...) -(s::Chain)(x) = foldl((x, m) -> m(x), x, s.layers) +(c::Chain)(x) = foldl((x, m) -> m(x), x, c.layers) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) @@ -78,3 +78,43 @@ function Base.show(io::IO, l::Dense) l.σ == identity || print(io, ", ", l.σ) print(io, ")") end + + +""" + Dropout(p; mode=:train) + +A Dropout layer. In `:train` mode sets input components `x[i]` to zero with +probability `p` and to `x[i]/(1-p)` with probability `(1-p)`. + +In `:eval` mode it doesn't alter the input: `x == Dropout(p; mode=:eval)(x)`. +Change the mode with [`setmode!`](@ref). +""" +mutable struct Dropout{F} + p::F + mode::Symbol +end +Dropout(p::F; mode=:train) where {F} = Dropout{F}(p, mode) + +function (a::Dropout)(x) + if a.mode == :eval + return x + else + if 0 < a.p < 1 + y = similar(x) + rand!(y) + q = 1 - a.p + @inbounds for i=1:length(y) + y[i] = y[i] > a.p ? 1 / q : 0 + end + return y .* x + elseif a.p == 0 + return x + elseif a.p == 1 + return zeros(x) + end + end +end + +setmode!(a, mode::Symbol) = nothing +setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c) +setmode!(a::Dropout, mode::Symbol) = a.mode = mode diff --git a/test/layers.jl b/test/layers.jl new file mode 100644 index 00000000..ead9c343 --- /dev/null +++ b/test/layers.jl @@ -0,0 +1,23 @@ +@testset "dropout" begin + x = [1.,2.,3.]
+ @test x === Dropout(0.1, mode=:eval)(x) + @test x === Dropout(0, mode=:train)(x) + @test all(zeros(x) .== Dropout(1, mode=:train)(x)) + + x = rand(100) + m = Dropout(0.9) + y = m(x) + @test count(a->a==0, y) > 50 + setmode!(m, :eval) + y = m(x) + @test count(a->a==0, y) == 0 + + x = rand(100) + m = Chain(Dense(100,100), + Dropout(0.9)) + y = m(x) + @test count(a->a.data[] == 0, y) > 50 + setmode!(m, :eval) + y = m(x) + @test count(a->a.data[] == 0, y) == 0 +end diff --git a/test/runtests.jl b/test/runtests.jl index 2ab0e447..b7b838df 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,5 +4,6 @@ using Flux, Base.Test include("utils.jl") include("tracker.jl") +include("layers.jl") end