more tests of array promotion for concatenation

# Conflicts:
#	test/tracker.jl
This commit is contained in:
Johan Gustafsson 2018-05-02 15:47:30 +02:00
parent cfdb16e609
commit 94bb064a0f
2 changed files with 22 additions and 16 deletions

View File

@ -96,10 +96,9 @@ end
for f in [:vcat, :hcat]
@eval begin
Base.$f(a::TrackedArray...) = track($f, a...)
Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...)
# assumes there is another function to match Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...)
# assumes there are other functions to match the more conservative signature without TrackedArray; ie `Base.$f(::Union{Matrix,Vector,RowVector}...)`
Base.$f(a::Union{TrackedArray,Matrix,Vector,RowVector}...) = track($f, a...)
end
end

View File

@ -29,13 +29,18 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> x', rand(5))
function simplepromotioncheck(f, A, B)
r0 = f(A, B)
r1 = f(param(A), B)
r2 = f(A, param(B))
r3 = f(param(A), param(B))
function promotiontest(f, A, B, C)
r0 = f(A, B, C)
r1 = f(param(A), B, C)
if ndims(A) <= 2
r2 = f(A, param(B), C)
r3 = f(A, B, param(C))
else
r2 = r3 = f(A, param(B), param(C))
end
r4 = f(param(A), param(B), param(C))
r1 == r2 == r3 && r0 == Flux.data(r1)
r1 == r2 == r3 == r4 && r0 == Flux.data(r4)
end
@testset "concat" begin
@ -51,18 +56,15 @@ end
@test gradtest(vcatf, rand(5)', rand(2,5))
end
@test simplepromotioncheck(vcat, rand(5), rand(5))
@testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
@test gradtest(hcatf, rand(5)', rand(1,3))
@test gradtest(hcatf, rand(5), rand(5,2))
end
@test simplepromotioncheck(hcat, rand(5), rand(5))
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 1:5
@ -73,9 +75,14 @@ end
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
@testset "issue #213" begin
A, B, C = rand(2,2), rand(2,2), rand(2,2)
@test vcat(A, B, C |> param) == vcat(param.((A,B,C))...)
@testset "promotiontest" begin
@test promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
@test promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
@test promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
@test promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
@testset "cat($dim, ...)" for dim in 1:5
@test promotiontest((x...) -> cat(dim, x...), rand(3,4,5), rand(3,4,5), rand(3,4,5))
end
end
end