Updated to place function definitions in the appropriate places.
This commit is contained in:
parent
7c12af065a
commit
924b8f49ec
@ -39,6 +39,19 @@ end
|
|||||||
|
|
||||||
trainable(m) = functor(m)[1]
|
trainable(m) = functor(m)[1]
|
||||||
|
|
||||||
|
"""
|
||||||
|
testmode!(m, mode = true)
|
||||||
|
|
||||||
|
Set a layer or model's test mode (see below).
|
||||||
|
Using `:auto` mode will treat any gradient computation as training.
|
||||||
|
|
||||||
|
Possible values include:
|
||||||
|
- `false` for training
|
||||||
|
- `true` for testing
|
||||||
|
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||||
|
"""
|
||||||
|
testmode!(m, mode) = nothing
|
||||||
|
|
||||||
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
||||||
|
|
||||||
function params!(p::Params, x, seen = IdSet())
|
function params!(p::Params, x, seen = IdSet())
|
||||||
|
@ -33,6 +33,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
|||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
|
||||||
|
testmode!(m::Chain, mode = true) = map(x -> testmode!(x, mode), m.layers)
|
||||||
|
|
||||||
function Base.show(io::IO, c::Chain)
|
function Base.show(io::IO, c::Chain)
|
||||||
print(io, "Chain(")
|
print(io, "Chain(")
|
||||||
join(io, c.layers, ", ")
|
join(io, c.layers, ", ")
|
||||||
|
@ -3,21 +3,6 @@ istraining() = false
|
|||||||
@adjoint istraining() = true, _ -> nothing
|
@adjoint istraining() = true, _ -> nothing
|
||||||
|
|
||||||
_isactive(m) = isnothing(m.active) ? istraining() : m.active
|
_isactive(m) = isnothing(m.active) ? istraining() : m.active
|
||||||
# @adjoint _isactive(m) = _isactive(m), Δ -> nothing
|
|
||||||
|
|
||||||
"""
|
|
||||||
testmode!(m, mode = :auto)
|
|
||||||
|
|
||||||
Set a layer or model's test mode (see below).
|
|
||||||
Using `:auto` mode will treat any gradient computation as training.
|
|
||||||
|
|
||||||
Possible values include:
|
|
||||||
- `false` for training
|
|
||||||
- `true` for testing
|
|
||||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
|
||||||
"""
|
|
||||||
testmode!(m, mode) = nothing
|
|
||||||
testmode!(m::Chain, mode = :auto) = map(x -> testmode!(x, mode), m.layers)
|
|
||||||
|
|
||||||
_dropout_shape(s, ::Colon) = size(s)
|
_dropout_shape(s, ::Colon) = size(s)
|
||||||
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
_dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...)
|
||||||
@ -58,7 +43,7 @@ function (a::Dropout)(x)
|
|||||||
return dropout(x, a.p; dims = a.dims)
|
return dropout(x, a.p; dims = a.dims)
|
||||||
end
|
end
|
||||||
|
|
||||||
testmode!(m::Dropout, mode = :auto) =
|
testmode!(m::Dropout, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||||
|
|
||||||
function Base.show(io::IO, d::Dropout)
|
function Base.show(io::IO, d::Dropout)
|
||||||
@ -97,7 +82,7 @@ function (a::AlphaDropout)(x)
|
|||||||
return x
|
return x
|
||||||
end
|
end
|
||||||
|
|
||||||
testmode!(m::AlphaDropout, mode = :auto) =
|
testmode!(m::AlphaDropout, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
@ -205,7 +190,7 @@ end
|
|||||||
|
|
||||||
@functor BatchNorm
|
@functor BatchNorm
|
||||||
|
|
||||||
testmode!(m::BatchNorm, mode = :auto) =
|
testmode!(m::BatchNorm, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||||
|
|
||||||
function Base.show(io::IO, l::BatchNorm)
|
function Base.show(io::IO, l::BatchNorm)
|
||||||
@ -304,7 +289,7 @@ end
|
|||||||
|
|
||||||
@functor InstanceNorm
|
@functor InstanceNorm
|
||||||
|
|
||||||
testmode!(m::InstanceNorm, mode = :auto) =
|
testmode!(m::InstanceNorm, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||||
|
|
||||||
function Base.show(io::IO, l::InstanceNorm)
|
function Base.show(io::IO, l::InstanceNorm)
|
||||||
@ -407,7 +392,7 @@ end
|
|||||||
|
|
||||||
@functor GroupNorm
|
@functor GroupNorm
|
||||||
|
|
||||||
testmode!(m::GroupNorm, mode = :auto) =
|
testmode!(m::GroupNorm, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
||||||
|
|
||||||
function Base.show(io::IO, l::GroupNorm)
|
function Base.show(io::IO, l::GroupNorm)
|
||||||
|
Loading…
Reference in New Issue
Block a user