implement `back` of `repmat`
This commit is contained in:
parent
261c6db371
commit
c00f7f850f
|
@ -96,6 +96,18 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
|
|||
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
|
||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
|
||||
|
||||
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
||||
Δ′ = similar(xs.data)
|
||||
S = size(xs.data)
|
||||
for (i,v) in enumerate(Δ)
|
||||
d1 = divrem(i-1, S[1]*m)
|
||||
x = d1[2] % S[1]+1
|
||||
y = d1[1] % S[2]+1
|
||||
Δ′[x, y] += v
|
||||
end
|
||||
back(xs, Δ′)
|
||||
end
|
||||
|
||||
function back(::typeof(vcat), Δ, xs...)
|
||||
i = Base.tail(map(_ -> :, size(Δ)))
|
||||
start = 0
|
||||
|
|
|
@ -32,6 +32,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
|
||||
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
|
||||
@test gradtest(x -> repmat(x, 5), rand(4,5))
|
||||
|
||||
@test gradtest(kron,rand(5), rand(3))
|
||||
@test gradtest(kron, rand(5), rand(3), rand(8))
|
||||
@test gradtest(kron,rand(5,1), rand(3,1))
|
||||
|
|
Loading…
Reference in New Issue