diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 018a25fc..be77634b 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -39,10 +39,13 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b)) -function back(::typeof(vcat), Δ, xs, ys) +function back(::typeof(vcat), Δ, xs...) i = Base.tail(map(_ -> :, size(Δ))) - @back(xs, Δ[1:size(xs,1), i...]) - @back(ys, Δ[size(xs,1)+1:end, i...]) + start = 0 + for xsi in xs + @back(xsi, Δ[start+1:start+size(xsi,1), i...]) + start += size(xsi, 1) + end end # Reductions