GroupNorm made to use istraining()

This commit is contained in:
thebhatman 2019-06-11 22:04:33 +05:30
parent a56cfb73c3
commit 11073dcd25

View File

@ -295,7 +295,6 @@ mutable struct GroupNorm{F,V,W,N,T}
σ²::W # moving std
ϵ::N
momentum::N
active::Bool
end
GroupNorm(chs::Integer, G::Integer, λ = identity;
@ -326,7 +325,7 @@ function(gn::GroupNorm)(x)
β = reshape(gn.β, affine_shape...)
y = reshape(x,((size(x))[1:end-2]...,channels_per_group,groups,batches))
if !gn.active
if !istraining()
og_shape = size(x)
μ = reshape(gn.μ, μ_affine_shape...) # Shape : (1,1,...C/G,G,1)
σ² = reshape(gn.σ², μ_affine_shape...) # Shape : (1,1,...C/G,G,1)