vcat with scalars

This commit is contained in:
Mike J Innes 2019-02-04 00:04:48 +00:00
parent 0469394715
commit 838070968e
2 changed files with 9 additions and 2 deletions

View File

@ -148,9 +148,9 @@ function combinations(xs, n)
[[x, c...] for x in xs, c in cs]
end
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
cnames = map(_ -> gensym(), c)
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
track($f, $(cnames...), x, xs...)
end

View File

@ -113,6 +113,13 @@ end
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end
@testset "scalars" begin
@test vcat(param([1, 2, 3]), 1) isa TrackedArray
@test vcat(1, param([1, 2, 3])) isa TrackedArray
@test hcat(1, param([1 2 3;])) isa TrackedArray
@test vcat(param(1), 2) isa TrackedArray
end
end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))