cat with multiple dims #156
Co-authored-by: americast <sayan.sinha@iitkgp.ac.in>
This commit is contained in:
parent
fb68529169
commit
1c189c62ed
|
@ -98,7 +98,7 @@ for f in [:vcat, :hcat]
|
|||
Base.$f(a::TrackedArray...) = track($f, a...)
|
||||
Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...)
|
||||
|
||||
# assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
|
||||
# assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
|
||||
Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...)
|
||||
end
|
||||
end
|
||||
|
@ -125,22 +125,21 @@ function back(::typeof(hcat), Δ, xs...)
|
|||
end
|
||||
end
|
||||
|
||||
Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...)
|
||||
Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...)
|
||||
Base.cat(dims, a::TrackedArray...) = track(Base.cat, dims, a...)
|
||||
Base.cat(dims, a::TrackedArray, b::Array...) = track(Base.cat, dims, a, b...)
|
||||
Base.cat(dims, a::Array, b::TrackedArray...) = track(Base.cat, dims, a, b...)
|
||||
|
||||
function back(::typeof(cat), Δ, dim, xs...)
|
||||
start = 0
|
||||
for xsi in xs
|
||||
if ndims(xsi) < dim
|
||||
i = map(_ -> :, size(xsi))
|
||||
j = ones(Int, dim-ndims(xsi)-1)
|
||||
@back(xsi, Δ[i..., j..., start+1])
|
||||
else
|
||||
i = fill(:, dim-1)
|
||||
j = fill(:, ndims(xsi)-dim)
|
||||
@back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...])
|
||||
end
|
||||
start += size(xsi, dim)
|
||||
function back(::typeof(cat), Δ, dims, Xs...)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
for xs in Xs
|
||||
dim_xs = 1:ndims(xs)
|
||||
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
||||
|
||||
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
||||
|
||||
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
|
||||
|
||||
start = start .+ till_xs
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -70,6 +70,8 @@ end
|
|||
@test gradtest(catdim, rand(5), rand(5))
|
||||
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
||||
end
|
||||
|
||||
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
|
||||
end
|
||||
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
|
|
Loading…
Reference in New Issue