diff --git a/src/functor.jl b/src/functor.jl index a36b5765..4edfbd98 100644 --- a/src/functor.jl +++ b/src/functor.jl @@ -39,6 +39,19 @@ end trainable(m) = functor(m)[1] +""" + testmode!(m, mode = true) + +Set a layer or model's test mode (see below). +Using `:auto` mode will treat any gradient computation as training. + +Possible values include: +- `false` for training +- `true` for testing +- `:auto` or `nothing` for Flux to detect the mode automatically +""" +testmode!(m, mode) = nothing + params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x) function params!(p::Params, x, seen = IdSet()) diff --git a/src/layers/basic.jl b/src/layers/basic.jl index 2a465208..6788f761 100644 --- a/src/layers/basic.jl +++ b/src/layers/basic.jl @@ -33,6 +33,8 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x)) Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...) +testmode!(m::Chain, mode = true) = map(x -> testmode!(x, mode), m.layers) + function Base.show(io::IO, c::Chain) print(io, "Chain(") join(io, c.layers, ", ") diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index ee6b6fdd..7b438bc2 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -3,21 +3,6 @@ istraining() = false @adjoint istraining() = true, _ -> nothing _isactive(m) = isnothing(m.active) ? istraining() : m.active -# @adjoint _isactive(m) = _isactive(m), Δ -> nothing - -""" - testmode!(m, mode = :auto) - -Set a layer or model's test mode (see below). -Using `:auto` mode will treat any gradient computation as training. - -Possible values include: -- `false` for training -- `true` for testing -- `:auto` or `nothing` for Flux to detect the mode automatically -""" -testmode!(m, mode) = nothing -testmode!(m::Chain, mode = :auto) = map(x -> testmode!(x, mode), m.layers) _dropout_shape(s, ::Colon) = size(s) _dropout_shape(s, dims) = tuple((i ∉ dims ? 1 : si for (i, si) ∈ enumerate(size(s)))...) @@ -58,7 +43,7 @@ function (a::Dropout)(x) return dropout(x, a.p; dims = a.dims) end -testmode!(m::Dropout, mode = :auto) = +testmode!(m::Dropout, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, d::Dropout) @@ -97,7 +82,7 @@ function (a::AlphaDropout)(x) return x end -testmode!(m::AlphaDropout, mode = :auto) = +testmode!(m::AlphaDropout, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) """ @@ -205,7 +190,7 @@ end @functor BatchNorm -testmode!(m::BatchNorm, mode = :auto) = +testmode!(m::BatchNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, l::BatchNorm) @@ -304,7 +289,7 @@ end @functor InstanceNorm -testmode!(m::InstanceNorm, mode = :auto) = +testmode!(m::InstanceNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, l::InstanceNorm) @@ -407,7 +392,7 @@ end @functor GroupNorm -testmode!(m::GroupNorm, mode = :auto) = +testmode!(m::GroupNorm, mode = true) = (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode) function Base.show(io::IO, l::GroupNorm)