better h/vcat, fixes #378
This commit is contained in:
parent
cdfc97f7c6
commit
6b11c552f3
|
@ -136,30 +136,30 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
|||
end
|
||||
end
|
||||
|
||||
for f in [:vcat, :hcat]
|
||||
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
||||
@eval begin
|
||||
# This section is a bit of a hack since julia doesn't have a standardised
|
||||
# promotion mechanism for concatenation yet
|
||||
# https://github.com/JuliaLang/julia/pull/20815
|
||||
function combinations(xs, n)
|
||||
n < 1 && return [[]]
|
||||
cs = combinations(xs, n-1)
|
||||
[[x, c...] for x in xs, c in cs]
|
||||
end
|
||||
|
||||
# It should support tracked concatenation with rank ∈ (1,2) with a
|
||||
# TrackedArray anywhere among the arguments This works as long as base has
|
||||
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
|
||||
Base.$f(a::$UArray...) = track($f, a...)
|
||||
combinations([AbstractArray, TrackedArray], 2)
|
||||
|
||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||
# first
|
||||
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
|
||||
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||
# second
|
||||
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
|
||||
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
|
||||
c::$UArray...) =
|
||||
track($f, a, b, c...) # resolves ambiguity introduced by previous row
|
||||
end
|
||||
for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
@grad function vcat(xs...)
|
||||
|
@ -192,10 +192,11 @@ end
|
|||
end
|
||||
end
|
||||
|
||||
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
|
||||
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
|
||||
track(cat, $(cnames...), x, xs..., dims = dims)
|
||||
end
|
||||
|
||||
@grad function cat(Xs...; dims)
|
||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||
|
|
|
@ -42,12 +42,7 @@ function promotiontest(f, A, B, C)
|
|||
r0 = f(A, B, C)
|
||||
r1 = f(param(A), B, C)
|
||||
r2 = f(A, param(B), C)
|
||||
if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
|
||||
r3 = f(A, B, param(C))
|
||||
else
|
||||
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
|
||||
r3 = r2
|
||||
end
|
||||
r3 = f(A, B, param(C))
|
||||
r4 = f(param(A), param(B), param(C))
|
||||
|
||||
@test !isa(r0, TrackedArray)
|
||||
|
|
Loading…
Reference in New Issue