better h/vcat, fixes #378

This commit is contained in:
Mike J Innes 2018-12-19 10:41:39 +00:00
parent 8341d14427
commit c3ca56f3ce
2 changed files with 27 additions and 31 deletions

View File

@ -110,30 +110,30 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
end end
end end
for f in [:vcat, :hcat] function combinations(xs, n)
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose}) n < 1 && return [[]]
@eval begin cs = combinations(xs, n-1)
# This section is a bit of a hack since julia doesn't have a standardised [[x, c...] for x in xs, c in cs]
# promotion mechanism for concatenation yet end
# https://github.com/JuliaLang/julia/pull/20815
# It should support tracked concatenation with rank ∈ (1,2) with a combinations([AbstractArray, TrackedArray], 2)
# TrackedArray anywhere among the arguments This works as long as base has
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::$UArray...) = track($f, a...)
# It should support tracked concatenation with rank>2 if the TrackedArray is for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
# first cnames = map(_ -> gensym(), c)
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...) @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row track($f, $(cnames...), x, xs...)
end
# It should support tracked concatenation with rank>2 if the TrackedArray is for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
# second cnames = map(_ -> gensym(), c)
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...) @eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray, track($f, $(cnames...), x, xs...)
c::$UArray...) = end
track($f, a, b, c...) # resolves ambiguity introduced by previous row
end for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
track($f, $(cnames...), x, xs...)
end end
@grad function vcat(xs...) @grad function vcat(xs...)
@ -166,10 +166,11 @@ end
end end
end end
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims) for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) cnames = map(_ -> gensym(), c)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) @eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims) track(cat, $(cnames...), x, xs..., dims = dims)
end
@grad function cat(Xs...; dims) @grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ) cat(data.(Xs)..., dims = dims), function (Δ)

View File

@ -37,12 +37,7 @@ function promotiontest(f, A, B, C)
r0 = f(A, B, C) r0 = f(A, B, C)
r1 = f(param(A), B, C) r1 = f(param(A), B, C)
r2 = f(A, param(B), C) r2 = f(A, param(B), C)
if all(ndims.((A,B,C)) .≤ 2) && f [hcat, vcat] r3 = f(A, B, param(C))
r3 = f(A, B, param(C))
else
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
r3 = r2
end
r4 = f(param(A), param(B), param(C)) r4 = f(param(A), param(B), param(C))
@test !isa(r0, TrackedArray) @test !isa(r0, TrackedArray)