From 363caeddc605beacc4dca41812192ce099fab637 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 5 Sep 2017 02:12:53 -0400 Subject: [PATCH] repmat forward --- src/tracker/Tracker.jl | 1 + src/tracker/lib.jl | 3 +++ 2 files changed, 4 insertions(+) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index e76f2feb..9510febc 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -26,6 +26,7 @@ end TrackedScalar{T,A} = TrackedArray{T,0,A} TrackedVector{T,A} = TrackedArray{T,1,A} TrackedMatrix{T,A} = TrackedArray{T,2,A} +TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray{eltype(A),ndims(A),A}(c, x, Δ) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 03046e87..3ac6ddde 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -25,6 +25,9 @@ Base.ctranspose(xs::TrackedArray) = TrackedArray(Call(ctranspose, xs)) back!(::typeof(transpose), Δ, xs) = @back!(xs, trim(xs, Δ.')) back!(::typeof(ctranspose), Δ, xs) = @back!(xs, trim(xs, Δ')) +Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a...)) +Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...)) + Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))