diff --git a/src/layers/normalise.jl b/src/layers/normalise.jl index 40e186dc..85224ed2 100644 --- a/src/layers/normalise.jl +++ b/src/layers/normalise.jl @@ -123,9 +123,12 @@ function (BN::BatchNorm)(x) γ, β = BN.γ, BN.β dims = length(size(x)) channels = size(x, dims-1) - affine_shape = ones(Int, dims) - affine_shape[end-1] = channels - m = prod(size(x)[1:end-2]) * size(x)[end] + affine_shape = let dims=dims, channels=channels + ntuple(i->i == dims - 1 ? channels : 1, dims) + end + m = let sz = size(x) + prod(ntuple(i->sz[i], dims-2)) * sz[end] + end if !BN.active μ = reshape(BN.μ, affine_shape...) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 3d9836d0..618ff1e2 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -421,30 +421,3 @@ function Base.Broadcast.materialize(bc::Broadcasted{TrackedStyle}) end using Requires - -# https://github.com/FluxML/Flux.jl/issues/353 -@init Requires.isprecompiling() || @eval Base.Broadcast begin - function flatten(bc::Broadcasted{Style}) where {Style} - isflat(bc) && return bc - args = cat_nested(bc) - let makeargs = make_makeargs(bc), f = bc.f - newf = @inline function(args::Vararg{Any,N}) where N - f(makeargs(args...)...) - end - return Broadcasted{Style}(newf, args, bc.axes) - end - end - @inline function make_makeargs(makeargs, t::Tuple{<:Broadcasted,Vararg{Any}}) - bc = t[1] - let makeargs = make_makeargs(makeargs, tail(t)), f = bc.f - let makeargs = make_makeargs(makeargs, bc.args) - headargs, tailargs = make_headargs(bc.args), make_tailargs(bc.args) - return @inline function(args::Vararg{Any,N}) where N - args1 = makeargs(args...) - a, b = headargs(args1...), tailargs(args1...) - (f(a...), b...) - end - end - end - end -end