setmode! -> testmode!

This commit is contained in:
CarloLucibello 2017-10-23 16:23:29 +02:00
parent 86c7c9246e
commit 536ab3861d
4 changed files with 21 additions and 20 deletions

View File

@ -9,7 +9,7 @@ using Lazy: @forward
export Chain, Dense, RNN, LSTM, Dropout, export Chain, Dense, RNN, LSTM, Dropout,
SGD, ADAM, Momentum, Nesterov, SGD, ADAM, Momentum, Nesterov,
param, params, mapleaves, setmode! param, params, mapleaves, testmode!
using NNlib using NNlib
export σ, relu, leakyrelu, elu, swish, softmax export σ, relu, leakyrelu, elu, swish, softmax

View File

@ -81,22 +81,22 @@ end
""" """
Dropout(p; mode=:train) Dropout(p; testmode=false)
A Dropout layer. In `:train` mode sets input components `x[i]` to zero with A Dropout layer. If `testmode=false` mode sets input components `x[i]` to zero with
probability `p` and to `x[i]/(1-p)` with probability `(1-p)`. probability `p` and to `x[i]/(1-p)` with probability `(1-p)`.
In `:eval` mode it doesn't alter the input: `x == Dropout(p; mode=:eval)(x)`. In `testmode=true`it doesn't alter the input: `x == Dropout(p; mode=:eval)(x)`.
Change the mode with [`setmode!`](@ref). Change the mode with [`testmode!`](@ref).
""" """
mutable struct Dropout{F} mutable struct Dropout{F}
p::F p::F
mode::Symbol testmode::Bool
end end
Dropout(p::F; mode=:train) where {F} = Dropout{F}(p, mode) Dropout(p::F; testmode::Bool=false) where {F} = Dropout{F}(p, testmode)
function (a::Dropout)(x) function (a::Dropout)(x)
if a.mode == :eval if a.testmode
return x return x
else else
if 0 < a.p < 1 if 0 < a.p < 1
@ -116,12 +116,9 @@ function (a::Dropout)(x)
end end
""" """
setmode!(m, mode::Symbol) testmode!(m, val=true)
Change the mode of model `m` to `mode`. Possible values for `mode` are Set model `m` in test mode if `val=true`, and in training mode otherwise.
`:train` and `:eval`. This has an affect only if `m` contains [`Dropout`](@ref) or `BatchNorm` layers.
This has an affect only if `m` contains [`Dropout`](@ref) of `BatchNorm` layers.
""" """
setmode!(a, mode::Symbol) = nothing testmode!(m, val::Bool=true) = prefor(x -> x isa Dropout && (x.testmode = val), m)
setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c)
setmode!(a::Dropout, mode::Symbol) = a.mode = mode

View File

@ -41,6 +41,7 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
param(xs) = TrackedArray(AbstractFloat.(xs)) param(xs) = TrackedArray(AbstractFloat.(xs))
istracked(x::TrackedArray) = true istracked(x::TrackedArray) = true
data(x::TrackedArray) = x.data data(x::TrackedArray) = x.data
# data(x::TrackedScalar) = x.data[]
grad(x::TrackedArray) = x.grad grad(x::TrackedArray) = x.grad
# Fallthrough methods # Fallthrough methods

View File

@ -1,23 +1,26 @@
@testset "dropout" begin @testset "dropout" begin
x = [1.,2.,3.] x = [1.,2.,3.]
@test x === Dropout(0.1, mode=:eval)(x) @test x === Dropout(0.1, testmode=true)(x)
@test x === Dropout(0, mode=:train)(x) @test x === Dropout(0, testmode=false)(x)
@test all(zeros(x) .== Dropout(1, mode=:train)(x)) @test all(zeros(x) .== Dropout(1, testmode=false)(x))
x = rand(100) x = rand(100)
m = Dropout(0.9) m = Dropout(0.9)
y = m(x) y = m(x)
@test count(a->a==0, y) > 50 @test count(a->a==0, y) > 50
setmode!(m, :eval) testmode!(m)
y = m(x) y = m(x)
@test count(a->a==0, y) == 0 @test count(a->a==0, y) == 0
testmode!(m, false)
y = m(x)
@test count(a->a==0, y) > 50
x = rand(100) x = rand(100)
m = Chain(Dense(100,100), m = Chain(Dense(100,100),
Dropout(0.9)) Dropout(0.9))
y = m(x) y = m(x)
@test count(a->a == 0, y) > 50 @test count(a->a == 0, y) > 50
setmode!(m, :eval) testmode!(m)
y = m(x) y = m(x)
@test count(a->a == 0, y) == 0 @test count(a->a == 0, y) == 0
end end