fixed white lines
This commit is contained in:
parent
d933f2079b
commit
4860c1d48b
@ -87,11 +87,14 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
|||||||
|
|
||||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||||
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||||
|
|
||||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||||
|
|
||||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
||||||
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
repeat(data(xs), inner = inner, outer = outer), function (Δ)
|
||||||
Δ′ = zero(xs)
|
Δ′ = zero(xs)
|
||||||
S = size(xs)
|
S = size(xs)
|
||||||
|
|
||||||
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
|
||||||
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
for (dest_idx, val) in pairs(IndexCartesian(), data(Δ))
|
||||||
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
|
||||||
@ -102,6 +105,7 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
|||||||
(nobacksies(:repeat, Δ′),)
|
(nobacksies(:repeat, Δ′),)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
for f in [:vcat, :hcat]
|
for f in [:vcat, :hcat]
|
||||||
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
||||||
@eval begin
|
@eval begin
|
||||||
|
Loading…
Reference in New Issue
Block a user