From c3ca56f3ce138c0f0bae9b5457d0752cc45592fc Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Wed, 19 Dec 2018 10:41:39 +0000 Subject: [PATCH] better h/vcat, fixes #378 --- src/tracker/array.jl | 51 ++++++++++++++++++++++---------------------- test/tracker.jl | 7 +----- 2 files changed, 27 insertions(+), 31 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 7d44ba0f..bb1dccd9 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -110,30 +110,30 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...) end end -for f in [:vcat, :hcat] - UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose}) - @eval begin - # This section is a bit of a hack since julia doesn't have a standardised - # promotion mechanism for concatenation yet - # https://github.com/JuliaLang/julia/pull/20815 +function combinations(xs, n) + n < 1 && return [[]] + cs = combinations(xs, n-1) + [[x, c...] for x in xs, c in cs] +end - # It should support tracked concatenation with rank ∈ (1,2) with a - # TrackedArray anywhere among the arguments This works as long as base has - # other functions that captures `(::Union{Vector,RowVector,Matrix}...)`. - Base.$f(a::$UArray...) = track($f, a...) +combinations([AbstractArray, TrackedArray], 2) - # It should support tracked concatenation with rank>2 if the TrackedArray is - # first - Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) - Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row +for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat] + cnames = map(_ -> gensym(), c) + @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) = + track($f, $(cnames...), x, xs...) +end - # It should support tracked concatenation with rank>2 if the TrackedArray is - # second - Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) - Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray, - c::$UArray...) = - track($f, a, b, c...) # resolves ambiguity introduced by previous row - end +for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat] + cnames = map(_ -> gensym(), c) + @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T = + track($f, $(cnames...), x, xs...) +end + +for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat] + cnames = map(_ -> gensym(), c) + @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T = + track($f, $(cnames...), x, xs...) end @grad function vcat(xs...) @@ -166,10 +166,11 @@ end end end -Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims) -Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) -Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) -Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) +for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i) + cnames = map(_ -> gensym(), c) + @eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) = + track(cat, $(cnames...), x, xs..., dims = dims) +end @grad function cat(Xs...; dims) cat(data.(Xs)..., dims = dims), function (Δ) diff --git a/test/tracker.jl b/test/tracker.jl index c868f302..994d74a0 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -37,12 +37,7 @@ function promotiontest(f, A, B, C) r0 = f(A, B, C) r1 = f(param(A), B, C) r2 = f(A, param(B), C) - if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat] - r3 = f(A, B, param(C)) - else - @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved - r3 = r2 - end + r3 = f(A, B, param(C)) r4 = f(param(A), param(B), param(C)) @test !isa(r0, TrackedArray)