setmode! -> testmode!
This commit is contained in:
parent
86c7c9246e
commit
536ab3861d
@ -9,7 +9,7 @@ using Lazy: @forward
|
|||||||
|
|
||||||
export Chain, Dense, RNN, LSTM, Dropout,
|
export Chain, Dense, RNN, LSTM, Dropout,
|
||||||
SGD, ADAM, Momentum, Nesterov,
|
SGD, ADAM, Momentum, Nesterov,
|
||||||
param, params, mapleaves, setmode!
|
param, params, mapleaves, testmode!
|
||||||
|
|
||||||
using NNlib
|
using NNlib
|
||||||
export σ, relu, leakyrelu, elu, swish, softmax
|
export σ, relu, leakyrelu, elu, swish, softmax
|
||||||
|
@ -81,22 +81,22 @@ end
|
|||||||
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Dropout(p; mode=:train)
|
Dropout(p; testmode=false)
|
||||||
|
|
||||||
A Dropout layer. In `:train` mode sets input components `x[i]` to zero with
|
A Dropout layer. If `testmode=false` mode sets input components `x[i]` to zero with
|
||||||
probability `p` and to `x[i]/(1-p)` with probability `(1-p)`.
|
probability `p` and to `x[i]/(1-p)` with probability `(1-p)`.
|
||||||
|
|
||||||
In `:eval` mode it doesn't alter the input: `x == Dropout(p; mode=:eval)(x)`.
|
In `testmode=true`it doesn't alter the input: `x == Dropout(p; mode=:eval)(x)`.
|
||||||
Change the mode with [`setmode!`](@ref).
|
Change the mode with [`testmode!`](@ref).
|
||||||
"""
|
"""
|
||||||
mutable struct Dropout{F}
|
mutable struct Dropout{F}
|
||||||
p::F
|
p::F
|
||||||
mode::Symbol
|
testmode::Bool
|
||||||
end
|
end
|
||||||
Dropout(p::F; mode=:train) where {F} = Dropout{F}(p, mode)
|
Dropout(p::F; testmode::Bool=false) where {F} = Dropout{F}(p, testmode)
|
||||||
|
|
||||||
function (a::Dropout)(x)
|
function (a::Dropout)(x)
|
||||||
if a.mode == :eval
|
if a.testmode
|
||||||
return x
|
return x
|
||||||
else
|
else
|
||||||
if 0 < a.p < 1
|
if 0 < a.p < 1
|
||||||
@ -116,12 +116,9 @@ function (a::Dropout)(x)
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
setmode!(m, mode::Symbol)
|
testmode!(m, val=true)
|
||||||
|
|
||||||
Change the mode of model `m` to `mode`. Possible values for `mode` are
|
Set model `m` in test mode if `val=true`, and in training mode otherwise.
|
||||||
`:train` and `:eval`.
|
This has an affect only if `m` contains [`Dropout`](@ref) or `BatchNorm` layers.
|
||||||
This has an affect only if `m` contains [`Dropout`](@ref) of `BatchNorm` layers.
|
|
||||||
"""
|
"""
|
||||||
setmode!(a, mode::Symbol) = nothing
|
testmode!(m, val::Bool=true) = prefor(x -> x isa Dropout && (x.testmode = val), m)
|
||||||
setmode!(c::Chain, mode::Symbol) = mapchildren(x->setmode!(x, mode), c)
|
|
||||||
setmode!(a::Dropout, mode::Symbol) = a.mode = mode
|
|
||||||
|
@ -41,6 +41,7 @@ TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
|||||||
param(xs) = TrackedArray(AbstractFloat.(xs))
|
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||||
istracked(x::TrackedArray) = true
|
istracked(x::TrackedArray) = true
|
||||||
data(x::TrackedArray) = x.data
|
data(x::TrackedArray) = x.data
|
||||||
|
# data(x::TrackedScalar) = x.data[]
|
||||||
grad(x::TrackedArray) = x.grad
|
grad(x::TrackedArray) = x.grad
|
||||||
|
|
||||||
# Fallthrough methods
|
# Fallthrough methods
|
||||||
|
@ -1,23 +1,26 @@
|
|||||||
@testset "dropout" begin
|
@testset "dropout" begin
|
||||||
x = [1.,2.,3.]
|
x = [1.,2.,3.]
|
||||||
@test x === Dropout(0.1, mode=:eval)(x)
|
@test x === Dropout(0.1, testmode=true)(x)
|
||||||
@test x === Dropout(0, mode=:train)(x)
|
@test x === Dropout(0, testmode=false)(x)
|
||||||
@test all(zeros(x) .== Dropout(1, mode=:train)(x))
|
@test all(zeros(x) .== Dropout(1, testmode=false)(x))
|
||||||
|
|
||||||
x = rand(100)
|
x = rand(100)
|
||||||
m = Dropout(0.9)
|
m = Dropout(0.9)
|
||||||
y = m(x)
|
y = m(x)
|
||||||
@test count(a->a==0, y) > 50
|
@test count(a->a==0, y) > 50
|
||||||
setmode!(m, :eval)
|
testmode!(m)
|
||||||
y = m(x)
|
y = m(x)
|
||||||
@test count(a->a==0, y) == 0
|
@test count(a->a==0, y) == 0
|
||||||
|
testmode!(m, false)
|
||||||
|
y = m(x)
|
||||||
|
@test count(a->a==0, y) > 50
|
||||||
|
|
||||||
x = rand(100)
|
x = rand(100)
|
||||||
m = Chain(Dense(100,100),
|
m = Chain(Dense(100,100),
|
||||||
Dropout(0.9))
|
Dropout(0.9))
|
||||||
y = m(x)
|
y = m(x)
|
||||||
@test count(a->a == 0, y) > 50
|
@test count(a->a == 0, y) > 50
|
||||||
setmode!(m, :eval)
|
testmode!(m)
|
||||||
y = m(x)
|
y = m(x)
|
||||||
@test count(a->a == 0, y) == 0
|
@test count(a->a == 0, y) == 0
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user