Work around GPU launch bug for BatchNorm

This commit is contained in:
Elliot Saba 2019-01-11 02:27:46 -05:00
parent f0d5624ed2
commit a740dadf6a
1 changed files with 4 additions and 1 deletions

View File

@ -138,7 +138,10 @@ function (BN::BatchNorm)(x)
end
let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ)) .+ reshape(β, affine_shape...))
# Break this up with a temporary variable to fix GPU launch bug
# https://github.com/FluxML/Flux.jl/issues/385
temp = reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ BN.ϵ))
return λ.(temp .+ reshape(β, affine_shape...))
end
end