more tests of array promotion for concatenation
# Conflicts: # test/tracker.jl
This commit is contained in:
parent
cfdb16e609
commit
94bb064a0f
|
@ -96,10 +96,9 @@ end
|
|||
for f in [:vcat, :hcat]
|
||||
@eval begin
|
||||
Base.$f(a::TrackedArray...) = track($f, a...)
|
||||
Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...)
|
||||
|
||||
# assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
|
||||
Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...)
|
||||
# assumes there are other functions to match the more conservative signature without TrackedArray; ie `Base.$f(::Union{Matrix,Vector,RowVector}...)`
|
||||
Base.$f(a::Union{TrackedArray,Matrix,Vector,RowVector}...) = track($f, a...)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -29,13 +29,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
|
||||
@test gradtest(x -> x', rand(5))
|
||||
|
||||
function simplepromotioncheck(f, A, B)
|
||||
r0 = f(A, B)
|
||||
r1 = f(param(A), B)
|
||||
r2 = f(A, param(B))
|
||||
r3 = f(param(A), param(B))
|
||||
function promotiontest(f, A, B, C)
|
||||
r0 = f(A, B, C)
|
||||
r1 = f(param(A), B, C)
|
||||
if ndims(A) <= 2
|
||||
r2 = f(A, param(B), C)
|
||||
r3 = f(A, B, param(C))
|
||||
else
|
||||
r2 = r3 = f(A, param(B), param(C))
|
||||
end
|
||||
r4 = f(param(A), param(B), param(C))
|
||||
|
||||
r1 == r2 == r3 && r0 == Flux.data(r1)
|
||||
r1 == r2 == r3 == r4 && r0 == Flux.data(r4)
|
||||
end
|
||||
|
||||
@testset "concat" begin
|
||||
|
@ -51,18 +56,15 @@ end
|
|||
@test gradtest(vcatf, rand(5)', rand(2,5))
|
||||
end
|
||||
|
||||
@test simplepromotioncheck(vcat, rand(5), rand(5))
|
||||
|
||||
@testset for hcatf in [hcat, cat2]
|
||||
@test gradtest(hcatf, rand(5), rand(5))
|
||||
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
|
||||
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
|
||||
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
|
||||
@test gradtest(hcatf, rand(5)', rand(1,3))
|
||||
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||
end
|
||||
|
||||
@test simplepromotioncheck(hcat, rand(5), rand(5))
|
||||
|
||||
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||
|
||||
@testset "cat($dim, ...)" for dim in 1:5
|
||||
|
@ -73,9 +75,14 @@ end
|
|||
|
||||
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
|
||||
|
||||
@testset "issue #213" begin
|
||||
A, B, C = rand(2,2), rand(2,2), rand(2,2)
|
||||
@test vcat(A, B, C |> param) == vcat(param.((A,B,C))...)
|
||||
@testset "promotiontest" begin
|
||||
@test promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
|
||||
@test promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
|
||||
@test promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
|
||||
@test promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
|
||||
@testset "cat($dim, ...)" for dim in 1:5
|
||||
@test promotiontest((x...) -> cat(dim, x...), rand(3,4,5), rand(3,4,5), rand(3,4,5))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in New Issue