gpu broadcast fix

This commit is contained in:
Mike J Innes 2018-04-17 18:05:58 +01:00
parent 8f73dc6e14
commit baff20514d

View File

@ -116,7 +116,7 @@ BatchNorm(chs::Integer, λ = identity;
function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
λ, γ, β = BN.λ, BN.γ, BN.β
γ, β = BN.γ, BN.β
dims = length(size(x))
channels = size(x, dims-1)
affine_shape = ones(Int, dims)
@ -140,7 +140,9 @@ function (BN::BatchNorm)(x)
BN.σ = (1 - mtm) .* BN.σ .+ mtm .* squeeze(data(σ), (axes...)) .* m ./ (m - 1)
end
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ σ) .+ reshape(β, affine_shape...))
end
end
children(BN::BatchNorm) =