diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index d3577a04..906a9331 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -45,8 +45,7 @@ tovec(xs::AbstractArray) = vec(xs) tovec(xs) = xs function back!(x::TrackedArray, Δ) - Δ′ = vec(x.Δ) - Δ′ .+= tovec(Δ) + x.Δ .+= Δ back!(x.f, Δ) end diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 76f420d3..9b88ef47 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -68,9 +68,11 @@ function tracked_broadcast(f, args::Vararg{Any,N}) where N TrackedArray(Call(b, args...), b()) end +trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)})) + unbroadcast(x, Δ) = size(x) == size(Δ) ? Δ : - sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))) + trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ)))) function back!(b::Broadcasted, Δ, args...) Δargs = ntuple(i -> Δ .* getindex.(partials.(b.data), i), length(args)) diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index 64ccd2ad..73c63029 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -1,6 +1,6 @@ function gradient(f, xs::AbstractArray...) xs = track.(xs) - back!(f(xs...), [1]) + back!(f(xs...)) grad.(xs) end