Support hcat and cat
This commit is contained in:
parent
13daaec1cb
commit
bcef5c4ab5
@ -96,6 +96,14 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
|
||||
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
|
||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
|
||||
|
||||
Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...)
|
||||
Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b)
|
||||
Base.hcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(hcat, a, b)
|
||||
|
||||
Base.cat(dim::Int, a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(cat, dim, a, b...)
|
||||
Base.cat(dim::Int, a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(cat, dim, a, b)
|
||||
Base.cat(dim::Int, a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(cat, dim, a, b)
|
||||
|
||||
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
||||
Δ′ = similar(xs.data)
|
||||
S = size(xs.data)
|
||||
@ -117,6 +125,34 @@ function back(::typeof(vcat), Δ, xs...)
|
||||
end
|
||||
end
|
||||
|
||||
function back(::typeof(hcat), Δ, xs...)
|
||||
i = fill(:, ndims(Δ)-2)
|
||||
start = 0
|
||||
for xsi in xs
|
||||
if ndims(xsi) == 1
|
||||
@back(xsi, Δ[:, start+1])
|
||||
else
|
||||
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
|
||||
end
|
||||
start += size(xsi, 2)
|
||||
end
|
||||
end
|
||||
|
||||
function back(::typeof(cat), Δ, dim, xs...)
|
||||
i = fill(:, dim-1)
|
||||
j = fill(:, ndims(Δ)-dim)
|
||||
start = 0
|
||||
for xsi in xs
|
||||
if ndims(xsi) < dim
|
||||
a = [fill(:, ndims(xsi)); ones(Int, dim-ndims(xsi)-1)]
|
||||
@back(xsi, Δ[a..., start+1])
|
||||
else
|
||||
@back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...])
|
||||
end
|
||||
start += size(xsi, dim)
|
||||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
||||
|
Loading…
Reference in New Issue
Block a user