Changed testmode! to return model
This commit is contained in:
parent
ba5259a269
commit
5cbd2cecf2
|
@ -50,7 +50,7 @@ Possible values include:
|
||||||
- `true` for testing
|
- `true` for testing
|
||||||
- `:auto` or `nothing` for Flux to detect the mode automatically
|
- `:auto` or `nothing` for Flux to detect the mode automatically
|
||||||
"""
|
"""
|
||||||
testmode!(m, mode) = nothing
|
testmode!(m, mode = true) = m
|
||||||
|
|
||||||
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
params!(p::Params, x::AbstractArray{<:Number}, seen = IdSet()) = push!(p, x)
|
||||||
|
|
||||||
|
|
|
@ -33,7 +33,7 @@ applychain(fs::Tuple, x) = applychain(tail(fs), first(fs)(x))
|
||||||
|
|
||||||
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
Base.getindex(c::Chain, i::AbstractArray) = Chain(c.layers[i]...)
|
||||||
|
|
||||||
testmode!(m::Chain, mode = true) = map(x -> testmode!(x, mode), m.layers)
|
testmode!(m::Chain, mode = true) = (map(x -> testmode!(x, mode), m.layers); m)
|
||||||
|
|
||||||
function Base.show(io::IO, c::Chain)
|
function Base.show(io::IO, c::Chain)
|
||||||
print(io, "Chain(")
|
print(io, "Chain(")
|
||||||
|
|
|
@ -44,7 +44,7 @@ function (a::Dropout)(x)
|
||||||
end
|
end
|
||||||
|
|
||||||
testmode!(m::Dropout, mode = true) =
|
testmode!(m::Dropout, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||||
|
|
||||||
function Base.show(io::IO, d::Dropout)
|
function Base.show(io::IO, d::Dropout)
|
||||||
print(io, "Dropout(", d.p)
|
print(io, "Dropout(", d.p)
|
||||||
|
@ -83,7 +83,7 @@ function (a::AlphaDropout)(x)
|
||||||
end
|
end
|
||||||
|
|
||||||
testmode!(m::AlphaDropout, mode = true) =
|
testmode!(m::AlphaDropout, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||||
|
|
||||||
"""
|
"""
|
||||||
LayerNorm(h::Integer)
|
LayerNorm(h::Integer)
|
||||||
|
@ -191,7 +191,7 @@ end
|
||||||
@functor BatchNorm
|
@functor BatchNorm
|
||||||
|
|
||||||
testmode!(m::BatchNorm, mode = true) =
|
testmode!(m::BatchNorm, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||||
|
|
||||||
function Base.show(io::IO, l::BatchNorm)
|
function Base.show(io::IO, l::BatchNorm)
|
||||||
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
print(io, "BatchNorm($(join(size(l.β), ", "))")
|
||||||
|
@ -290,7 +290,7 @@ end
|
||||||
@functor InstanceNorm
|
@functor InstanceNorm
|
||||||
|
|
||||||
testmode!(m::InstanceNorm, mode = true) =
|
testmode!(m::InstanceNorm, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||||
|
|
||||||
function Base.show(io::IO, l::InstanceNorm)
|
function Base.show(io::IO, l::InstanceNorm)
|
||||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||||
|
@ -393,7 +393,7 @@ end
|
||||||
@functor GroupNorm
|
@functor GroupNorm
|
||||||
|
|
||||||
testmode!(m::GroupNorm, mode = true) =
|
testmode!(m::GroupNorm, mode = true) =
|
||||||
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode)
|
(m.active = (isnothing(mode) || mode == :auto) ? nothing : !mode; m)
|
||||||
|
|
||||||
function Base.show(io::IO, l::GroupNorm)
|
function Base.show(io::IO, l::GroupNorm)
|
||||||
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
print(io, "GroupNorm($(join(size(l.β), ", "))")
|
||||||
|
|
|
@ -85,19 +85,19 @@ end
|
||||||
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
|
@test isapprox(y, sigmoid.((x .- m.μ) ./ sqrt.(m.σ² .+ m.ϵ)), atol = 1.0e-7)
|
||||||
end
|
end
|
||||||
|
|
||||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:6), 3, 2, 1)
|
let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:6), 3, 2, 1)
|
||||||
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
y = reshape(permutedims(x, [2, 1, 3]), 2, :)
|
||||||
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
y = permutedims(reshape(m(y), 2, 3, 1), [2, 1, 3])
|
||||||
@test m(x) == y
|
@test m(x) == y
|
||||||
end
|
end
|
||||||
|
|
||||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:12), 2, 3, 2, 1)
|
let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:12), 2, 3, 2, 1)
|
||||||
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
y = reshape(permutedims(x, [3, 1, 2, 4]), 2, :)
|
||||||
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
y = permutedims(reshape(m(y), 2, 2, 3, 1), [2, 3, 1, 4])
|
||||||
@test m(x) == y
|
@test m(x) == y
|
||||||
end
|
end
|
||||||
|
|
||||||
let m = trainmode(BatchNorm(2)), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
|
let m = testmode!(BatchNorm(2), false), x = reshape(Float32.(1:24), 2, 2, 3, 2, 1)
|
||||||
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
y = reshape(permutedims(x, [4, 1, 2, 3, 5]), 2, :)
|
||||||
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||||
@test m(x) == y
|
@test m(x) == y
|
||||||
|
@ -165,7 +165,7 @@ end
|
||||||
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
|
@test isapprox(y, sigmoid.((x .- expand_inst(m.μ, affine_shape)) ./ sqrt.(expand_inst(m.σ², affine_shape) .+ m.ϵ)), atol = 1.0e-7)
|
||||||
end
|
end
|
||||||
|
|
||||||
let m = trainmode(InstanceNorm(2)), sizes = (2, 4, 1, 2, 3),
|
let m = testmode!(InstanceNorm(2), false), sizes = (2, 4, 1, 2, 3),
|
||||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||||
y = reshape(m(y), sizes...)
|
y = reshape(m(y), sizes...)
|
||||||
|
@ -182,7 +182,7 @@ end
|
||||||
end
|
end
|
||||||
|
|
||||||
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
# show that instance norm is equal to batch norm when channel and batch dims are squashed
|
||||||
let m_inorm = trainmode(InstanceNorm(2)), m_bnorm = trainmode(BatchNorm(12)), sizes = (5, 5, 3, 4, 2, 6),
|
let m_inorm = testmode!(InstanceNorm(2), false), m_bnorm = testmode!(BatchNorm(12), false), sizes = (5, 5, 3, 4, 2, 6),
|
||||||
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
x = reshape(Float32.(collect(1:prod(sizes))), sizes)
|
||||||
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
@test m_inorm(x) == reshape(m_bnorm(reshape(x, (sizes[1:end - 2]..., :, 1))), sizes)
|
||||||
end
|
end
|
||||||
|
@ -266,7 +266,7 @@ if VERSION >= v"1.1"
|
||||||
@test isapprox(y, out, atol = 1.0e-7)
|
@test isapprox(y, out, atol = 1.0e-7)
|
||||||
end
|
end
|
||||||
|
|
||||||
let m = trainmode(GroupNorm(2,2)), sizes = (2, 4, 1, 2, 3),
|
let m = testmode!(GroupNorm(2,2), false), sizes = (2, 4, 1, 2, 3),
|
||||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||||
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
y = reshape(permutedims(x, [3, 1, 2, 4, 5]), :, 2, 3)
|
||||||
y = reshape(m(y), sizes...)
|
y = reshape(m(y), sizes...)
|
||||||
|
@ -283,13 +283,13 @@ if VERSION >= v"1.1"
|
||||||
end
|
end
|
||||||
|
|
||||||
# show that group norm is the same as instance norm when the group size is the same as the number of channels
|
# show that group norm is the same as instance norm when the group size is the same as the number of channels
|
||||||
let IN = trainmode(InstanceNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,5),
|
let IN = testmode!(InstanceNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,5),
|
||||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||||
@test IN(x) ≈ GN(x)
|
@test IN(x) ≈ GN(x)
|
||||||
end
|
end
|
||||||
|
|
||||||
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
# show that group norm is the same as batch norm for a group of size 1 and batch of size 1
|
||||||
let BN = trainmode(BatchNorm(4)), GN = trainmode(GroupNorm(4,4)), sizes = (2,2,3,4,1),
|
let BN = testmode!(BatchNorm(4), false), GN = testmode!(GroupNorm(4,4), false), sizes = (2,2,3,4,1),
|
||||||
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
x = Float32.(reshape(collect(1:prod(sizes)), sizes))
|
||||||
@test BN(x) ≈ GN(x)
|
@test BN(x) ≈ GN(x)
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue