Support hcat and cat

This commit is contained in:
Johan Gustafsson 2018-05-02 15:56:08 +02:00
parent 13daaec1cb
commit bcef5c4ab5

View File

@ -96,6 +96,14 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...)
Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b)
Base.hcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(hcat, a, b)
Base.cat(dim::Int, a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(cat, dim, a, b...)
Base.cat(dim::Int, a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(cat, dim, a, b)
Base.cat(dim::Int, a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(cat, dim, a, b)
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
Δ′ = similar(xs.data)
S = size(xs.data)
@ -117,6 +125,34 @@ function back(::typeof(vcat), Δ, xs...)
end
end
function back(::typeof(hcat), Δ, xs...)
i = fill(:, ndims(Δ)-2)
start = 0
for xsi in xs
if ndims(xsi) == 1
@back(xsi, Δ[:, start+1])
else
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
end
start += size(xsi, 2)
end
end
function back(::typeof(cat), Δ, dim, xs...)
i = fill(:, dim-1)
j = fill(:, ndims(Δ)-dim)
start = 0
for xsi in xs
if ndims(xsi) < dim
a = [fill(:, ndims(xsi)); ones(Int, dim-ndims(xsi)-1)]
@back(xsi, Δ[a..., start+1])
else
@back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...])
end
start += size(xsi, dim)
end
end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)