better h/vcat, fixes #378
This commit is contained in:
parent
8341d14427
commit
c3ca56f3ce
@ -110,30 +110,30 @@ Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
for f in [:vcat, :hcat]
|
function combinations(xs, n)
|
||||||
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
n < 1 && return [[]]
|
||||||
@eval begin
|
cs = combinations(xs, n-1)
|
||||||
# This section is a bit of a hack since julia doesn't have a standardised
|
[[x, c...] for x in xs, c in cs]
|
||||||
# promotion mechanism for concatenation yet
|
end
|
||||||
# https://github.com/JuliaLang/julia/pull/20815
|
|
||||||
|
|
||||||
# It should support tracked concatenation with rank ∈ (1,2) with a
|
combinations([AbstractArray, TrackedArray], 2)
|
||||||
# TrackedArray anywhere among the arguments This works as long as base has
|
|
||||||
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
|
|
||||||
Base.$f(a::$UArray...) = track($f, a...)
|
|
||||||
|
|
||||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
|
||||||
# first
|
cnames = map(_ -> gensym(), c)
|
||||||
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
|
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
|
||||||
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
track($f, $(cnames...), x, xs...)
|
||||||
|
end
|
||||||
|
|
||||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
for i = 0:2, c = combinations([:AbstractVecOrMat, :TrackedVecOrMat], i), f = [:hcat, :vcat]
|
||||||
# second
|
cnames = map(_ -> gensym(), c)
|
||||||
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
|
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVecOrMat{T}, xs::AbstractVecOrMat{T}...) where T =
|
||||||
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
|
track($f, $(cnames...), x, xs...)
|
||||||
c::$UArray...) =
|
end
|
||||||
track($f, a, b, c...) # resolves ambiguity introduced by previous row
|
|
||||||
end
|
for i = 0:2, c = combinations([:AbstractVector, :TrackedVector], i), f = [:hcat, :vcat]
|
||||||
|
cnames = map(_ -> gensym(), c)
|
||||||
|
@eval Base.$f($([:($x::$c{T}) for (x, c) in zip(cnames, c)]...), x::TrackedVector{T}, xs::AbstractVector{T}...) where T =
|
||||||
|
track($f, $(cnames...), x, xs...)
|
||||||
end
|
end
|
||||||
|
|
||||||
@grad function vcat(xs...)
|
@grad function vcat(xs...)
|
||||||
@ -166,10 +166,11 @@ end
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.cat(a::TrackedArray; dims) = track(cat, a, dims = dims)
|
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i)
|
||||||
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
cnames = map(_ -> gensym(), c)
|
||||||
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
@eval Base.cat($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...; dims) =
|
||||||
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track(cat, a, b, c..., dims = dims)
|
track(cat, $(cnames...), x, xs..., dims = dims)
|
||||||
|
end
|
||||||
|
|
||||||
@grad function cat(Xs...; dims)
|
@grad function cat(Xs...; dims)
|
||||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||||
|
@ -37,12 +37,7 @@ function promotiontest(f, A, B, C)
|
|||||||
r0 = f(A, B, C)
|
r0 = f(A, B, C)
|
||||||
r1 = f(param(A), B, C)
|
r1 = f(param(A), B, C)
|
||||||
r2 = f(A, param(B), C)
|
r2 = f(A, param(B), C)
|
||||||
if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
|
r3 = f(A, B, param(C))
|
||||||
r3 = f(A, B, param(C))
|
|
||||||
else
|
|
||||||
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
|
|
||||||
r3 = r2
|
|
||||||
end
|
|
||||||
r4 = f(param(A), param(B), param(C))
|
r4 = f(param(A), param(B), param(C))
|
||||||
|
|
||||||
@test !isa(r0, TrackedArray)
|
@test !isa(r0, TrackedArray)
|
||||||
|
Loading…
Reference in New Issue
Block a user