RowVector tests

This commit is contained in:
Johan Gustafsson 2018-05-02 14:57:32 +02:00
parent 94bb064a0f
commit 5fc6190956
2 changed files with 49 additions and 18 deletions

View File

@ -95,10 +95,19 @@ end
for f in [:vcat, :hcat] for f in [:vcat, :hcat]
@eval begin @eval begin
Base.$f(a::TrackedArray...) = track($f, a...) # This section is a bit of a hack since julia doesn't have a standardised promotion mechanism for concatenation yet https://github.com/JuliaLang/julia/pull/20815
# assumes there are other functions to match the more conservative signature without TrackedArray; ie `Base.$f(::Union{Matrix,Vector,RowVector}...)` # It should support tracked concatenation with rank ∈ (1,2) with a TrackedArray anywhere among the arguments
Base.$f(a::Union{TrackedArray,Matrix,Vector,RowVector}...) = track($f, a...) # This works as long as base has other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
# It should support tracked concatenation with rank>2 if the TrackedArray is first
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
# It should support tracked concatenation with rank>2 if the TrackedArray is second
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray, c::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b, c...) # resolves ambiguity introduced by previous row
end end
end end
@ -124,9 +133,8 @@ function back(::typeof(hcat), Δ, xs...)
end end
end end
Base.cat(dims, a::TrackedArray...) = track(Base.cat, dims, a...) Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::TrackedArray, b::Array...) = track(Base.cat, dims, a, b...) Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
Base.cat(dims, a::Array, b::TrackedArray...) = track(Base.cat, dims, a, b...)
function back(::typeof(cat), Δ, dims, Xs...) function back(::typeof(cat), Δ, dims, Xs...)
start = ntuple(i -> 0, Val{ndims(Δ)}) start = ntuple(i -> 0, Val{ndims(Δ)})

View File

@ -32,15 +32,19 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
function promotiontest(f, A, B, C) function promotiontest(f, A, B, C)
r0 = f(A, B, C) r0 = f(A, B, C)
r1 = f(param(A), B, C) r1 = f(param(A), B, C)
if ndims(A) <= 2 r2 = f(A, param(B), C)
r2 = f(A, param(B), C) if all(ndims.((A,B,C)) .≤ 2) && f [hcat, vcat]
r3 = f(A, B, param(C)) r3 = f(A, B, param(C))
else else
r2 = r3 = f(A, param(B), param(C)) @test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
r3 = r2
end end
r4 = f(param(A), param(B), param(C)) r4 = f(param(A), param(B), param(C))
r1 == r2 == r3 == r4 && r0 == Flux.data(r4) @test !isa(r0, TrackedArray)
@test all(isa.([r1,r2,r3,r4], TrackedArray))
@test r1 == r2 == r3 == r4
@test r0 == Flux.data(r4)
end end
@testset "concat" begin @testset "concat" begin
@ -50,6 +54,7 @@ end
@testset for vcatf in [vcat, cat1] @testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3)) @test gradtest(vcatf, rand(5), rand(3))
@test gradtest(vcatf, rand(5), rand(3), rand(8)) @test gradtest(vcatf, rand(5), rand(3), rand(8))
@test gradtest(vcatf, rand(5)', rand(5)')
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2)) @test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3)) @test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
@test gradtest(vcatf, rand(5), rand(3,1)) @test gradtest(vcatf, rand(5), rand(3,1))
@ -58,31 +63,49 @@ end
@testset for hcatf in [hcat, cat2] @testset for hcatf in [hcat, cat2]
@test gradtest(hcatf, rand(5), rand(5)) @test gradtest(hcatf, rand(5), rand(5))
@test gradtest(hcatf, rand(5)', rand(5)')
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8)) @test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3)) @test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
@test gradtest(hcatf, rand(5), rand(5), rand(5,2)) @test gradtest(hcatf, rand(5), rand(5), rand(5,2))
@test gradtest(hcatf, rand(5)', rand(1,3)) @test gradtest(hcatf, rand(5)', rand(1,3))
@test gradtest(hcatf, rand(5), rand(5,2)) @test gradtest(hcatf, rand(5), rand(5,2))
end
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
end end
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4)) @test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 1:5 @testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(dim, x...) catdim = (x...) -> cat(dim, x...)
@test gradtest(catdim, rand(5), rand(5)) @test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5)) @test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
end end
@test !isa(vcat(rand(2)), TrackedArray)
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray)
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1)) @test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
@testset "promotiontest" begin @testset "promotiontest" begin
@test promotiontest(vcat, rand(1,2), rand(2)', rand(2,2)) @testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
@test promotiontest(hcat, rand(2,1), rand(2), rand(2,2)) promotiontest(fcat, rand(2), rand(2), rand(2))
@test promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5)) promotiontest(fcat, rand(2)', rand(2)', rand(2)')
@test promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5)) promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
@testset "cat($dim, ...)" for dim in 1:5 promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
@test promotiontest((x...) -> cat(dim, x...), rand(3,4,5), rand(3,4,5), rand(3,4,5))
end end
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
end end
end end