Merge pull request #270 from staticfloat/sf/tracked_repeat
Add `TrackedArray` support for `repeat(x; inner, outer)`
This commit is contained in:
commit
af8f3348eb
|
@ -93,6 +93,26 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
|||
back(xs, Δ′)
|
||||
end
|
||||
|
||||
|
||||
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer)
|
||||
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer)
|
||||
|
||||
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer)
|
||||
Δ′ = similar(xs.data)
|
||||
Δ′ .= 0
|
||||
S = size(xs.data)
|
||||
|
||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||
for (dest_idx, val) in enumerate(IndexCartesian(), Δ)
|
||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||
# wrap around based on original size S.
|
||||
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
|
||||
Δ′[src_idx...] += val
|
||||
end
|
||||
back(xs, Δ′)
|
||||
end
|
||||
|
||||
|
||||
for f in [:vcat, :hcat]
|
||||
@eval begin
|
||||
# This section is a bit of a hack since julia doesn't have a standardised
|
||||
|
|
|
@ -114,6 +114,9 @@ end
|
|||
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
|
||||
@test gradtest(x -> repmat(x, 5), rand(4,5))
|
||||
|
||||
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
|
||||
@test gradtest(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
|
||||
|
||||
@test gradtest(kron, rand(5), rand(3))
|
||||
@test gradtest(kron, rand(5), rand(3), rand(8))
|
||||
@test gradtest(kron, rand(5,1), rand(3,1))
|
||||
|
|
Loading…
Reference in New Issue