cat with multiple dims #156

Co-authored-by: americast <sayan.sinha@iitkgp.ac.in>
This commit is contained in:
Johan Gustafsson 2018-05-02 09:03:54 +02:00
parent fb68529169
commit 1c189c62ed
2 changed files with 17 additions and 16 deletions

View File

@ -98,7 +98,7 @@ for f in [:vcat, :hcat]
Base.$f(a::TrackedArray...) = track($f, a...)
Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...)
# assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
# assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...)
end
end
@ -125,22 +125,21 @@ function back(::typeof(hcat), Δ, xs...)
end
end
Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...)
Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...)
Base.cat(dims, a::TrackedArray...) = track(Base.cat, dims, a...)
Base.cat(dims, a::TrackedArray, b::Array...) = track(Base.cat, dims, a, b...)
Base.cat(dims, a::Array, b::TrackedArray...) = track(Base.cat, dims, a, b...)
function back(::typeof(cat), Δ, dim, xs...)
start = 0
for xsi in xs
if ndims(xsi) < dim
i = map(_ -> :, size(xsi))
j = ones(Int, dim-ndims(xsi)-1)
@back(xsi, Δ[i..., j..., start+1])
else
i = fill(:, dim-1)
j = fill(:, ndims(xsi)-dim)
@back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...])
end
start += size(xsi, dim)
function back(::typeof(cat), Δ, dims, Xs...)
start = ntuple(i -> 0, Val{ndims(Δ)})
for xs in Xs
dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
start = start .+ till_xs
end
end

View File

@ -70,6 +70,8 @@ end
@test gradtest(catdim, rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
end
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))