diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index e76f2feb..9510febc 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -26,6 +26,7 @@ end TrackedScalar{T,A} = TrackedArray{T,0,A} TrackedVector{T,A} = TrackedArray{T,1,A} TrackedMatrix{T,A} = TrackedArray{T,2,A} +TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray{eltype(A),ndims(A),A}(c, x, Δ) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 03046e87..3ac6ddde 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -25,6 +25,9 @@ Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs)) back!(::typeof(transpose), Δ, xs) = @back!(xs, trim(xs, Δ.')) back!(::typeof(ctranspose), Δ, xs) = @back!(xs, trim(xs, Δ')) +Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...)) +Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...)) + Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))