gpu broadcast fix
This commit is contained in:
parent
8f73dc6e14
commit
baff20514d
@ -116,7 +116,7 @@ BatchNorm(chs::Integer, λ = identity;
|
||||
function (BN::BatchNorm)(x)
|
||||
size(x, ndims(x)-1) == length(BN.β) ||
|
||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||
λ, γ, β = BN.λ, BN.γ, BN.β
|
||||
γ, β = BN.γ, BN.β
|
||||
dims = length(size(x))
|
||||
channels = size(x, dims-1)
|
||||
affine_shape = ones(Int, dims)
|
||||
@ -140,8 +140,10 @@ function (BN::BatchNorm)(x)
|
||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
|
||||
end
|
||||
|
||||
let λ = BN.λ
|
||||
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
|
||||
end
|
||||
end
|
||||
|
||||
children(BN::BatchNorm) =
|
||||
(BN.λ, BN.β, BN.γ, BN.μ, BN.σ, BN.ϵ, BN.momentum, BN.active)
|
||||
|
Loading…
Reference in New Issue
Block a user