vcat with scalars
This commit is contained in:
parent
0469394715
commit
838070968e
|
@ -148,9 +148,9 @@ function combinations(xs, n)
|
|||
[[x, c...] for x in xs, c in cs]
|
||||
end
|
||||
|
||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray], i), f = [:hcat, :vcat]
|
||||
for i = 0:2, c = combinations([:AbstractArray, :TrackedArray, :Number], i), f = [:hcat, :vcat]
|
||||
cnames = map(_ -> gensym(), c)
|
||||
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::TrackedArray, xs::AbstractArray...) =
|
||||
@eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) =
|
||||
track($f, $(cnames...), x, xs...)
|
||||
end
|
||||
|
||||
|
|
|
@ -113,6 +113,13 @@ end
|
|||
promotiontest((x...) -> cat(x..., dims = 3), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
||||
end
|
||||
|
||||
@testset "scalars" begin
|
||||
@test vcat(param([1, 2, 3]), 1) isa TrackedArray
|
||||
@test vcat(1, param([1, 2, 3])) isa TrackedArray
|
||||
@test hcat(1, param([1 2 3;])) isa TrackedArray
|
||||
@test vcat(param(1), 2) isa TrackedArray
|
||||
end
|
||||
|
||||
end
|
||||
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
|
|
Loading…
Reference in New Issue