change IN to in
This commit is contained in:
parent
83b4b3a714
commit
7b9b64f1cb
@ -203,9 +203,9 @@ InstanceNorm(chs::Integer, λ = identity;
|
|||||||
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
InstanceNorm(λ, param(initβ(chs)), param(initγ(chs)),
|
||||||
zeros(chs), ones(chs), ϵ, momentum, true)
|
zeros(chs), ones(chs), ϵ, momentum, true)
|
||||||
|
|
||||||
function (IN::InstanceNorm)(x)
|
function (in::InstanceNorm)(x)
|
||||||
size(x, ndims(x)-1) == length(IN.β) ||
|
size(x, ndims(x)-1) == length(in.β) ||
|
||||||
error("InstanceNorm expected $(length(IN.β)) channels, got $(size(x, ndims(x)-1))")
|
error("InstanceNorm expected $(length(in.β)) channels, got $(size(x, ndims(x)-1))")
|
||||||
ndims(x) > 2 ||
|
ndims(x) > 2 ||
|
||||||
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
|
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
|
||||||
# these are repeated later on depending on the batch size
|
# these are repeated later on depending on the batch size
|
||||||
@ -216,39 +216,39 @@ function (IN::InstanceNorm)(x)
|
|||||||
affine_shape[end-1] = c
|
affine_shape[end-1] = c
|
||||||
affine_shape[end] = bs
|
affine_shape[end] = bs
|
||||||
m = prod(size(x)[1:end-2])
|
m = prod(size(x)[1:end-2])
|
||||||
γ, β = expand_inst(IN.γ, affine_shape), expand_inst(IN.β, affine_shape)
|
γ, β = expand_inst(in.γ, affine_shape), expand_inst(in.β, affine_shape)
|
||||||
|
|
||||||
if !IN.active
|
if !in.active
|
||||||
μ = expand_inst(IN.μ, affine_shape)
|
μ = expand_inst(in.μ, affine_shape)
|
||||||
σ² = expand_inst(IN.σ², affine_shape)
|
σ² = expand_inst(in.σ², affine_shape)
|
||||||
ϵ = IN.ϵ
|
ϵ = in.ϵ
|
||||||
else
|
else
|
||||||
T = eltype(x)
|
T = eltype(x)
|
||||||
|
|
||||||
ϵ = data(convert(T, IN.ϵ))
|
ϵ = data(convert(T, in.ϵ))
|
||||||
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
axes = 1:dims-2 # axes to reduce along (all but channels and batch size axes)
|
||||||
μ = mean(x, dims = axes)
|
μ = mean(x, dims = axes)
|
||||||
σ² = mean((x .- μ) .^ 2, dims = axes)
|
σ² = mean((x .- μ) .^ 2, dims = axes)
|
||||||
|
|
||||||
# update moving mean/std
|
# update moving mean/std
|
||||||
mtm = data(convert(T, IN.momentum))
|
mtm = data(convert(T, in.momentum))
|
||||||
IN.μ = dropdims(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
in.μ = dropdims(mean(repeat((1 - mtm) .* in.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), dims=2)
|
||||||
IN.σ² = dropdims(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
in.σ² = dropdims(mean((repeat((1 - mtm) .* in.σ², outer=[1, bs]) .+ (mtm * m / (m - 1)) .* reshape(data(σ²), (c, bs))), dims = 2), dims=2)
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = IN.λ
|
let λ = in.λ
|
||||||
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||||
λ.(γ .* x̂ .+ β)
|
λ.(γ .* x̂ .+ β)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
children(IN::InstanceNorm) =
|
children(in::InstanceNorm) =
|
||||||
(IN.λ, IN.β, IN.γ, IN.μ, IN.σ², IN.ϵ, IN.momentum, IN.active)
|
(in.λ, in.β, in.γ, in.μ, in.σ², in.ϵ, in.momentum, in.active)
|
||||||
|
|
||||||
mapchildren(f, IN::InstanceNorm) = # e.g. mapchildren(cu, IN)
|
mapchildren(f, in::InstanceNorm) = # e.g. mapchildren(cu, in)
|
||||||
InstanceNorm(IN.λ, f(IN.β), f(IN.γ), f(IN.μ), f(IN.σ²), IN.ϵ, IN.momentum, IN.active)
|
InstanceNorm(in.λ, f(in.β), f(in.γ), f(in.μ), f(in.σ²), in.ϵ, in.momentum, in.active)
|
||||||
|
|
||||||
_testmode!(IN::InstanceNorm, test) = (IN.active = !test)
|
_testmode!(in::InstanceNorm, test) = (in.active = !test)
|
||||||
|
|
||||||
function Base.show(io::IO, l::InstanceNorm)
|
function Base.show(io::IO, l::InstanceNorm)
|
||||||
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
print(io, "InstanceNorm($(join(size(l.β), ", "))")
|
||||||
|
Loading…
Reference in New Issue
Block a user