GroupNorm made to use istraining()
This commit is contained in:
parent
a56cfb73c3
commit
11073dcd25
@ -295,7 +295,6 @@ mutable struct GroupNorm{F,V,W,N,T}
|
||||
σ²::W # moving std
|
||||
ϵ::N
|
||||
momentum::N
|
||||
active::Bool
|
||||
end
|
||||
|
||||
GroupNorm(chs::Integer, G::Integer, λ = identity;
|
||||
@ -326,7 +325,7 @@ function(gn::GroupNorm)(x)
|
||||
β = reshape(gn.β, affine_shape...)
|
||||
|
||||
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
|
||||
if !gn.active
|
||||
if !istraining()
|
||||
og_shape = size(x)
|
||||
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
|
||||
|
Loading…
Reference in New Issue
Block a user