diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 202a2ca2..85dbdc41 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -87,11 +87,14 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs) @grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),) @grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) + Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) + @grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs))) repeat(data(xs), inner = inner, outer = outer), function (Δ) Δ′ = zero(xs) S = size(xs) + # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ for (dest_idx, val) in pairs(IndexCartesian(), data(Δ)) # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then @@ -102,6 +105,7 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) (nobacksies(:repeat, Δ′),) end end + for f in [:vcat, :hcat] UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose}) @eval begin