Merge pull request #586 from KristofferC/kc/batchnorm
Work around an extreme slowdown in BatchNorm caused by a Julia performance bug in broadcast fusion
This commit is contained in:
commit
601e2d8ae0
|
@ -138,7 +138,9 @@ function (BN::BatchNorm)(x)
|
|||
end
|
||||
|
||||
let λ = BN.λ
|
||||
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...))
|
||||
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
|
||||
# This is intentionally not fused into a single broadcast, because fusing it causes an extreme slowdown
|
||||
λ.(temp .+ reshape(β, affine_shape...))
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -98,4 +98,9 @@ end
|
|||
y = permutedims(reshape(m(y), 2, 2, 2, 3, 1), [2, 3, 4, 1, 5])
|
||||
@test m(x) == y
|
||||
end
|
||||
|
||||
let m = BatchNorm(32), x = randn(Float32, 416, 416, 32, 1);
|
||||
m(x)
|
||||
@test (@allocated m(x)) < 100_000_000
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue