From 5cbd2cecf29cf58a4e4bd97e637515c299a522d8 Mon Sep 17 00:00:00 2001
From: Kyle Daruwalla
Date: Sat, 29 Feb 2020 16:09:59 -0600
Subject: [PATCH] Changed testmode! to return model

---
 src/functor.jl               |  2 +-
 src/layers/basic.jl          |  2 +-
 src/layers/normalise.jl      | 10 +++++-----
 test/layers/normalisation.jl | 16 ++++++++--------
 4 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/src/functor.jl b/src/functor.jl
index 4edfbd98..ee384b98 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -50,7 +50,7 @@ Possible values include:
 - `true` for testing
 - `:auto` or `nothing` for Flux to detect the mode automatically
 """
-testmode!(m, mode) = nothing
+testmode!(m, mode = true) = m
 
 params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
 
diff --git a/src/layers/basic.jl b/src/layers/basic.jl
index 6788f761..10d1f07b 100644
--- a/src/layers/basic.jl
+++ b/src/layers/basic.jl
@@ -33,7 +33,7 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
 
 Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
 
-testmode!(m::Chain, mode = true) = map(x -> testmode!(x, mode), m.layers)
+testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m)
 
 function Base.show(io::IO, c::Chain)
   print(io, "Chain(")
diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl
index 7b438bc2..36c6d2bd 100644
--- a/src/layers/normalise.jl
+++ b/src/layers/normalise.jl
@@ -44,7 +44,7 @@ function (a::Dropout)(x)
 end
 
 testmode!(m::Dropout, mode = true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 function Base.show(io::IO, d::Dropout)
   print(io, "Dropout(", d.p)
@@ -83,7 +83,7 @@ function (a::AlphaDropout)(x)
 end
 
 testmode!(m::AlphaDropout, mode = true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 """
     LayerNorm(h::Integer)
@@ -191,7 +191,7 @@ end
 @functor BatchNorm
 
 testmode!(m::BatchNorm, mode = true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 function Base.show(io::IO, l::BatchNorm)
   print(io, "BatchNorm($(join(size(l.β), ", "))")
@@ -290,7 +290,7 @@ end
 @functor InstanceNorm
 
 testmode!(m::InstanceNorm, mode = true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 function Base.show(io::IO, l::InstanceNorm)
   print(io, "InstanceNorm($(join(size(l.β), ", "))")
@@ -393,7 +393,7 @@ end
 @functor GroupNorm
 
 testmode!(m::GroupNorm, mode = true) =
-  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
+  (m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
 
 function Base.show(io::IO, l::GroupNorm)
   print(io, "GroupNorm($(join(size(l.β), ", "))")
diff --git a/test/layers/normalisation.jl b/test/layers/normalisation.jl
index 594fb586..79bd9c77 100644
--- a/test/layers/normalisation.jl
+++ b/test/layers/normalisation.jl
@@ -85,19 +85,19 @@ end
     @test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
   end
 
-  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
+  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1)
     y = reshape(permutedims(x, [2, 1, 3]), 2, :)
     y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
     @test m(x) == y
   end
 
-  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
+  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1)
     y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
     @test m(x) == y
   end
 
-  let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
+  let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
     y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
     y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
     @test m(x) == y
@@ -165,7 +165,7 @@ end
     @test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
   end
 
-  let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
+  let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3),
     x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
@@ -182,7 +182,7 @@ end
   end
 
   # show that instance norm is equal to batch norm when channel and batch dims are squashed
-  let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
+  let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6),
     x = reshape(Float32.(collect(1:prod(sizes))), sizes)
     @test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
   end
@@ -266,7 +266,7 @@ if VERSION >= v"1.1"
     @test isapprox(y, out, atol = 1.0e-7)
   end
 
-  let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
+  let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3),
     x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
     y = reshape(m(y), sizes...)
@@ -283,13 +283,13 @@ if VERSION >= v"1.1"
   end
 
   # show that group norm is the same as instance norm when the group size is the same as the number of channels
-  let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
+  let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5),
     x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test IN(x) ≈ GN(x)
   end
 
   # show that group norm is the same as batch norm for a group of size 1 and batch of size 1
-  let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
+  let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1),
     x = Float32.(reshape(collect(1:prod(sizes)), sizes))
     @test BN(x) ≈ GN(x)
   end
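---

Usage note (illustration only, not part of the patch): because `testmode!` now
returns the model instead of `nothing`, calls compose inline. A minimal sketch,
assuming a Flux build with this change applied; the chain and input below are
made up for illustration:

    using Flux

    m = Chain(Dense(10, 5, relu), Dropout(0.5), Dense(5, 2))

    # `testmode!(m)` defaults to mode = true (test mode) and returns `m`,
    # so the model can be switched and called in one expression.
    y = testmode!(m)(rand(Float32, 10))   # dropout disabled for this call

    # Passing `false` restores training behaviour (dropout active again),
    # mirroring the `testmode!(BatchNorm(2), false)` pattern in the tests.
    m = testmode!(m, false)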