This commit is contained in:
Avik Pal 2018-09-11 15:58:17 +05:30
parent 5fd8ffa47e
commit 7e83852862
2 changed files with 14 additions and 13 deletions

View File

@ -124,32 +124,35 @@ function (BN::BatchNorm)(x)
size(x, ndims(x)-1) == length(BN.β) ||
error("BatchNorm expected $(length(BN.β)) channels, got $(size(x, ndims(x)-1))")
γ, β = BN.γ, BN.β
dims = ndims(x)
dims = length(size(x))
channels = size(x, dims-1)
affine_shape = ones(Int, dims)
affine_shape[end-1] = size(x, dims-1)
T = eltype(x)
affine_shape[end-1] = channels
m = prod(size(x)[1:end-2]) * size(x)[end]
if !BN.active
μ = reshape(BN.μ, affine_shape...)
σ² = reshape(BN.σ², affine_shape...)
σ = reshape(BN.σ, affine_shape...)
else
T = eltype(x)
ϵ = data(convert(T, BN.ϵ))
axes = [1:dims-2; dims] # axes to reduce along (all but channels axis)
m = prod(size(x, axes...))
μ = mean(x, axes)
σ² = sum((x.-μ).^2, axes) ./ m
μ = mean(x, dims = axes)
σ² = sum((x .- μ) .^ 2, dims = axes) ./ m
# update moving mean/std
mtm = data(convert(T, BN.momentum))
BN.μ = ((1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...)))
BN.μ = (1 - mtm) .* BN.μ .+ mtm .* dropdims(data(μ), dims = (axes...,))
BN.σ² = ((1 - mtm) .* BN.σ² .+ mtm .* dropdims(data(σ²), dims = (axes...)) .* m ./ (m - 1))
end
ϵ = convert(T, BN.ϵ)
BN.λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ ϵ)) .+ reshape(β, affine_shape...))
let λ = BN.λ
λ.(reshape(γ, affine_shape...) .* ((x .- μ) ./ sqrt.(σ² .+ ϵ)) .+ reshape(β, affine_shape...))
end
end
treelike(BatchNorm)
@treelike BatchNorm
_testmode!(BN::BatchNorm, test) = (BN.active = !test)

View File

@ -1,8 +1,6 @@
using Flux, Flux.Tracker, CuArrays, Test
using Flux.Tracker: TrackedArray, data
@info "Testing Flux CUDNN"
@testset "CUDNN BatchNorm" begin
x = TrackedArray(rand(10, 10, 3, 1))
m = BatchNorm(3)