diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 3ac6ddde..c0767a23 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -32,9 +32,9 @@ Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::TrackedArray, b::TrackedArray) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::TrackedArray, b::AbstractArray) = TrackedArray(Call(vcat, a, b)) -Base.vcat(a::AbstractArray, b::TrackedArray) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b)) function back!(::typeof(vcat), Δ, xs, ys) i = Base.tail(map(_ -> :, size(Δ)))