gpu broadcast fix
This commit is contained in:
parent
8f73dc6e14
commit
baff20514d
@ -116,7 +116,7 @@ BatchNorm(chs::Integer, λ = identity;
|
|||||||
function (BN::BatchNorm)(x)
|
function (BN::BatchNorm)(x)
|
||||||
size(x, ndims(x)-1) == length(BN.β) ||
|
size(x, ndims(x)-1) == length(BN.β) ||
|
||||||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
|
||||||
λ, γ, β = BN.λ, BN.γ, BN.β
|
γ, β = BN.γ, BN.β
|
||||||
dims = length(size(x))
|
dims = length(size(x))
|
||||||
channels = size(x, dims-1)
|
channels = size(x, dims-1)
|
||||||
affine_shape = ones(Int, dims)
|
affine_shape = ones(Int, dims)
|
||||||
@ -140,7 +140,9 @@ function (BN::BatchNorm)(x)
|
|||||||
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
|
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
|
||||||
end
|
end
|
||||||
|
|
||||||
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
|
let λ = BN.λ
|
||||||
|
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
children(BN::BatchNorm) =
|
children(BN::BatchNorm) =
|
||||||
|
Loading…
Reference in New Issue
Block a user