From bcef5c4ab512fd84ef44a1e97c465a24f1001977 Mon Sep 17 00:00:00 2001 From: Johan Gustafsson Date: Wed, 2 May 2018 15:56:08 +0200 Subject: [PATCH] Support hcat and cat --- src/tracker/array.jl | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 4dfb2c6d..0bfabf36 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -96,6 +96,14 @@ Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...) Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b) Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b) +Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...) +Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b) +Base.hcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(hcat, a, b) + +Base.cat(dim::Int, a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(cat, dim, a, b...) +Base.cat(dim::Int, a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(cat, dim, a, b) +Base.cat(dim::Int, a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(cat, dim, a, b) + function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) Δ′ = similar(xs.data) S = size(xs.data) @@ -117,6 +125,34 @@ function back(::typeof(vcat), Δ, xs...) end end +function back(::typeof(hcat), Δ, xs...) + i = fill(:, ndims(Δ)-2) + start = 0 + for xsi in xs + if ndims(xsi) == 1 + @back(xsi, Δ[:, start+1]) + else + @back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) + end + start += size(xsi, 2) + end +end + +function back(::typeof(cat), Δ, dim, xs...) + i = fill(:, dim-1) + j = fill(:, ndims(Δ)-dim) + start = 0 + for xsi in xs + if ndims(xsi) < dim + a = [fill(:, ndims(xsi)); ones(Int, dim-ndims(xsi)-1)] + @back(xsi, Δ[a..., start+1]) + else + @back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...]) + end + start += size(xsi, dim) + end +end + Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)