Updated to place function definitions in the appropriate places.
This commit is contained in:
parent
7c12af065a
commit
924b8f49ec
|
@ -39,6 +39,19 @@ end
|
|||
|
||||
trainable(m) = functor(m)[1]
|
||||
|
||||
"""
|
||||
testmode!(m, mode = true)
|
||||
|
||||
Set a layer or model's test mode (see below).
|
||||
Using `:auto` mode will treat any gradient computation as training.
|
||||
|
||||
Possible values include:
|
||||
- `false` for training
|
||||
- `true` for testing
|
||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||
"""
|
||||
testmode!(m, mode) = nothing
|
||||
|
||||
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
||||
|
||||
function params!(p::Params, x, seen = IdSet())
|
||||
|
|
|
@ -33,6 +33,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
|||
|
||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||
|
||||
testmode!(m::Chain, mode = true) = map(x -> testmode!(x, mode), m.layers)
|
||||
|
||||
function Base.show(io::IO, c::Chain)
|
||||
print(io, "Chain(")
|
||||
join(io, c.layers, ", ")
|
||||
|
|
|
@ -3,21 +3,6 @@ istraining() = false
|
|||
@adjoint istraining() = true, _ -> nothing
|
||||
|
||||
_isactive(m) = isnothing(m.active) ? istraining() : m.active
|
||||
# @adjoint _isactive(m) = _isactive(m), Δ -> nothing
|
||||
|
||||
"""
|
||||
testmode!(m, mode = :auto)
|
||||
|
||||
Set a layer or model's test mode (see below).
|
||||
Using `:auto` mode will treat any gradient computation as training.
|
||||
|
||||
Possible values include:
|
||||
- `false` for training
|
||||
- `true` for testing
|
||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||
"""
|
||||
testmode!(m, mode) = nothing
|
||||
testmode!(m::Chain, mode = :auto) = map(x -> testmode!(x, mode), m.layers)
|
||||
|
||||
_dropout_shape(s, ::Colon) = size(s)
|
||||
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
||||
|
@ -58,7 +43,7 @@ function (a::Dropout)(x)
|
|||
return dropout(x, a.p; dims = a.dims)
|
||||
end
|
||||
|
||||
testmode!(m::Dropout, mode = :auto) =
|
||||
testmode!(m::Dropout, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||
|
||||
function Base.show(io::IO, d::Dropout)
|
||||
|
@ -97,7 +82,7 @@ function (a::AlphaDropout)(x)
|
|||
return x
|
||||
end
|
||||
|
||||
testmode!(m::AlphaDropout, mode = :auto) =
|
||||
testmode!(m::AlphaDropout, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||
|
||||
"""
|
||||
|
@ -205,7 +190,7 @@ end
|
|||
|
||||
@functor BatchNorm
|
||||
|
||||
testmode!(m::BatchNorm, mode = :auto) =
|
||||
testmode!(m::BatchNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||
|
||||
function Base.show(io::IO, l::BatchNorm)
|
||||
|
@ -304,7 +289,7 @@ end
|
|||
|
||||
@functor InstanceNorm
|
||||
|
||||
testmode!(m::InstanceNorm, mode = :auto) =
|
||||
testmode!(m::InstanceNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||
|
||||
function Base.show(io::IO, l::InstanceNorm)
|
||||
|
@ -407,7 +392,7 @@ end
|
|||
|
||||
@functor GroupNorm
|
||||
|
||||
testmode!(m::GroupNorm, mode = :auto) =
|
||||
testmode!(m::GroupNorm, mode = true) =
|
||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||
|
||||
function Base.show(io::IO, l::GroupNorm)
|
||||
|
|
Loading…
Reference in New Issue