implement `back` of `repmat`

This commit is contained in:
chengchingwen 2018-03-07 20:40:00 +08:00
parent 261c6db371
commit c00f7f850f
2 changed files with 15 additions and 0 deletions

View File

@ -96,6 +96,18 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
Δ′ = similar(xs.data)
S = size(xs.data)
for (i,v) in enumerate(Δ)
d1 = divrem(i-1, S[1]*m)
x = d1[2] % S[1]+1
y = d1[1] % S[2]+1
Δ′[x, y] += v
end
back(xs, Δ′)
end
function back(::typeof(vcat), Δ, xs...)
i = Base.tail(map(_ -> :, size(Δ)))
start = 0

View File

@ -32,6 +32,9 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5))
@test gradtest(kron,rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron,rand(5,1), rand(3,1))