diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index c6276887..f252bb2f 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -42,7 +42,8 @@ data(x::TrackedArray) = x.x grad(x::TrackedArray) = x.Δ function back!(x::TrackedArray, Δ) - x.Δ .+= Δ + Δ′ = vec(x.Δ) + Δ′ .+= vec(Δ) back!(x.f, Δ) end diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 6003ebd0..17ead5c1 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -57,9 +57,13 @@ function tracked_broadcast(f, args::Vararg{Any,N}) where N TrackedArray(Call(b, args...), b()) end +unbroadcast(x, Δ) = + size(x) == size(Δ) ? Δ : + sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))) + function back!(b::Broadcasted, Δ, args...) Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args)) - map((x, Δ) -> @back!(x, Δ), args, Δargs) + foreach((x, Δ) -> @back!(x, unbroadcast(x, Δ)), args, Δargs) return end