Changes based on the improved BatchNorm in PR #633
This commit is contained in:
parent
129a708b6f
commit
c41f891005
@ -185,6 +185,7 @@ m = Chain(
|
|||||||
softmax)
|
softmax)
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
# Tile the array `x` `last(shape)` times along a second dimension (at the
# InstanceNorm call sites the last entry of `shape` is the batch size), then
# reshape the tiled result to the dimensions given by `shape`.
# `shape` may be any indexable collection of integers (vector or tuple).
#
# Defined as a named function instead of an anonymous function bound to a
# non-constant global, so that calls from the layer's forward pass can be
# type-inferred rather than dynamically dispatched through an untyped global.
expand_inst(x, shape) = reshape(repeat(x, outer=[1, last(shape)]), shape...)
|
||||||
mutable struct InstanceNorm{F,V,W,N}
|
mutable struct InstanceNorm{F,V,W,N}
|
||||||
λ::F # activation function
|
λ::F # activation function
|
||||||
β::V # bias
|
β::V # bias
|
||||||
@ -207,7 +208,6 @@ function (IN::InstanceNorm)(x)
|
|||||||
ndims(x) > 2 ||
|
ndims(x) > 2 ||
|
||||||
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
|
error("InstanceNorm requires at least 3 dimensions. With 2 dimensions an array of zeros would be returned")
|
||||||
# these are repeated later on depending on the batch size
|
# these are repeated later on depending on the batch size
|
||||||
γ, β = IN.γ, IN.β
|
|
||||||
dims = length(size(x))
|
dims = length(size(x))
|
||||||
c = size(x, dims-1)
|
c = size(x, dims-1)
|
||||||
bs = size(x, dims)
|
bs = size(x, dims)
|
||||||
@ -215,10 +215,12 @@ function (IN::InstanceNorm)(x)
|
|||||||
affine_shape[end-1] = c
|
affine_shape[end-1] = c
|
||||||
affine_shape[end] = bs
|
affine_shape[end] = bs
|
||||||
m = prod(size(x)[1:end-2])
|
m = prod(size(x)[1:end-2])
|
||||||
|
γ, β = expand_inst(IN.γ, affine_shape), expand_inst(IN.β, affine_shape)
|
||||||
|
|
||||||
if !IN.active
|
if !IN.active
|
||||||
μ = reshape(repeat(IN.μ, outer=[bs]), affine_shape...)
|
μ = expand_inst(IN.μ, affine_shape)
|
||||||
σ² = reshape(repeat(IN.σ², outer=[bs]), affine_shape...)
|
σ² = expand_inst(IN.σ², affine_shape)
|
||||||
|
ϵ = IN.ϵ
|
||||||
else
|
else
|
||||||
T = eltype(x)
|
T = eltype(x)
|
||||||
|
|
||||||
@ -229,14 +231,13 @@ function (IN::InstanceNorm)(x)
|
|||||||
|
|
||||||
# update moving mean/std
|
# update moving mean/std
|
||||||
mtm = data(convert(T, IN.momentum))
|
mtm = data(convert(T, IN.momentum))
|
||||||
IN.μ = reshape(mean((1 - mtm) .* repeat(IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
|
IN.μ = reshape(mean(repeat((1 - mtm) .* IN.μ, outer=[1, bs]) .+ mtm .* reshape(data(μ), (c, bs)), dims = 2), :)
|
||||||
IN.σ² = reshape(mean(((1 - mtm) .* repeat(IN.σ², outer=[1, bs]) .+ mtm .* reshape(data(σ²), (c, bs)) .* (m / (m - 1))), dims = 2), :)
|
IN.σ² = reshape(mean((repeat((1 - mtm) .* IN.σ², outer=[1, bs]) .+ reshape(data(σ²), (c, bs)) .* (mtm * m / (m - 1))), dims = 2), :)
|
||||||
end
|
end
|
||||||
|
|
||||||
let λ = IN.λ
|
let λ = IN.λ
|
||||||
temp = reshape(repeat(γ, outer=[bs]), affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ IN.ϵ))
|
x̂ = (x .- μ) ./ sqrt.(σ² .+ ϵ)
|
||||||
# This is intentionally not fused because of an extreme slowdown doing so
|
λ.(γ .* x̂ .+ β)
|
||||||
λ.(temp .+ reshape(repeat(β, outer=[bs]), affine_shape...))
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user