Merge pull request #243 from gustafsson/catdim
Support for hcat and cat
This commit is contained in:
commit
24ad384a38
@ -81,21 +81,6 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
|
|||||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||||
|
|
||||||
Base.vcat(a::TrackedVector, b::TrackedVector) = track(vcat, a, b)
|
|
||||||
Base.vcat(a::TrackedVector, b::TrackedVector...) = track(vcat, a, b...)
|
|
||||||
Base.vcat(a::TrackedVector, b::AbstractVector) = track(vcat, a, b)
|
|
||||||
Base.vcat(a::AbstractVector, b::TrackedVector) = track(vcat, a, b)
|
|
||||||
|
|
||||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
|
|
||||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = track(vcat, a, b...)
|
|
||||||
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = track(vcat, a, b)
|
|
||||||
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = track(vcat, a, b)
|
|
||||||
|
|
||||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = track(vcat, a, b)
|
|
||||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = track(vcat, a, b...)
|
|
||||||
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = track(vcat, a, b)
|
|
||||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = track(vcat, a, b)
|
|
||||||
|
|
||||||
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
||||||
Δ′ = similar(xs.data)
|
Δ′ = similar(xs.data)
|
||||||
S = size(xs.data)
|
S = size(xs.data)
|
||||||
@ -108,15 +93,70 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
|||||||
back(xs, Δ′)
|
back(xs, Δ′)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
for f in [:vcat, :hcat]
|
||||||
|
@eval begin
|
||||||
|
# This section is a bit of a hack since julia doesn't have a standardised
|
||||||
|
# promotion mechanism for concatenation yet
|
||||||
|
# https://github.com/JuliaLang/julia/pull/20815
|
||||||
|
|
||||||
|
# It should support tracked concatenation with rank ∈ (1,2) with a
|
||||||
|
# TrackedArray anywhere among the arguments This works as long as base has
|
||||||
|
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
|
||||||
|
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
|
||||||
|
|
||||||
|
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||||
|
# first
|
||||||
|
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
|
||||||
|
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
||||||
|
|
||||||
|
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||||
|
# second
|
||||||
|
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
|
||||||
|
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
|
||||||
|
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
|
||||||
|
track($f, a, b, c...) # resolves ambiguity introduced by previous row
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function back(::typeof(vcat), Δ, xs...)
|
function back(::typeof(vcat), Δ, xs...)
|
||||||
i = Base.tail(map(_ -> :, size(Δ)))
|
|
||||||
start = 0
|
start = 0
|
||||||
for xsi in xs
|
for xsi in xs
|
||||||
|
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||||
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
||||||
start += size(xsi, 1)
|
start += size(xsi, 1)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function back(::typeof(hcat), Δ, xs...)
|
||||||
|
start = 0
|
||||||
|
for xsi in xs
|
||||||
|
if ndims(xsi) == 1
|
||||||
|
@back(xsi, Δ[:, start+1])
|
||||||
|
else
|
||||||
|
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||||
|
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
|
||||||
|
end
|
||||||
|
start += size(xsi, 2)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
|
||||||
|
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
|
||||||
|
|
||||||
|
function back(::typeof(cat), Δ, dims, Xs...)
|
||||||
|
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||||
|
for xs in Xs
|
||||||
|
dim_xs = 1:ndims(xs)
|
||||||
|
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
|
||||||
|
|
||||||
|
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
|
||||||
|
|
||||||
|
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
|
||||||
|
|
||||||
|
start = start .+ till_xs
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
||||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
|
||||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
|
||||||
|
@ -29,17 +29,94 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
|
|
||||||
@test gradtest(x -> x', rand(5))
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
@test gradtest(vcat, rand(5), rand(3))
|
function promotiontest(f, A, B, C)
|
||||||
@test gradtest(vcat, rand(5), rand(3), rand(8))
|
r0 = f(A, B, C)
|
||||||
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
|
r1 = f(param(A), B, C)
|
||||||
|
r2 = f(A, param(B), C)
|
||||||
|
if all(ndims.((A,B,C)) .≤ 2) && f ∈ [hcat, vcat]
|
||||||
|
r3 = f(A, B, param(C))
|
||||||
|
else
|
||||||
|
@test_throws MethodError f(A, B, param(C)) # until julia#20815 is resolved
|
||||||
|
r3 = r2
|
||||||
|
end
|
||||||
|
r4 = f(param(A), param(B), param(C))
|
||||||
|
|
||||||
|
@test !isa(r0, TrackedArray)
|
||||||
|
@test all(isa.([r1,r2,r3,r4], TrackedArray))
|
||||||
|
@test r1 == r2 == r3 == r4
|
||||||
|
@test r0 == Flux.data(r4)
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "concat" begin
|
||||||
|
cat1(x...) = cat(1, x...)
|
||||||
|
cat2(x...) = cat(2, x...)
|
||||||
|
|
||||||
|
@testset for vcatf in [vcat, cat1]
|
||||||
|
@test gradtest(vcatf, rand(5), rand(3))
|
||||||
|
@test gradtest(vcatf, rand(5), rand(3), rand(8))
|
||||||
|
@test gradtest(vcatf, rand(5)', rand(5)')
|
||||||
|
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
|
||||||
|
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
|
||||||
|
@test gradtest(vcatf, rand(5), rand(3,1))
|
||||||
|
@test gradtest(vcatf, rand(5)', rand(2,5))
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset for hcatf in [hcat, cat2]
|
||||||
|
@test gradtest(hcatf, rand(5), rand(5))
|
||||||
|
@test gradtest(hcatf, rand(5)', rand(5)')
|
||||||
|
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
|
||||||
|
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
|
||||||
|
@test gradtest(hcatf, rand(5), rand(5), rand(5,2))
|
||||||
|
@test gradtest(hcatf, rand(5)', rand(1,3))
|
||||||
|
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||||
|
@test gradtest(catf, rand(5))
|
||||||
|
@test gradtest(catf, rand(5)')
|
||||||
|
@test gradtest(catf, rand(2,5))
|
||||||
|
@test gradtest(catf, rand(2,5,3))
|
||||||
|
end
|
||||||
|
|
||||||
|
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||||
|
|
||||||
|
@testset "cat($dim, ...)" for dim in 3:5
|
||||||
|
catdim = (x...) -> cat(dim, x...)
|
||||||
|
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
||||||
|
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
||||||
|
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
||||||
|
end
|
||||||
|
|
||||||
|
@test !isa(vcat(rand(2)), TrackedArray)
|
||||||
|
@test !isa(hcat(rand(2)), TrackedArray)
|
||||||
|
@test !isa(cat(1,rand(2)), TrackedArray)
|
||||||
|
|
||||||
|
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
|
||||||
|
|
||||||
|
@testset "promotiontest" begin
|
||||||
|
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||||
|
promotiontest(fcat, rand(2), rand(2), rand(2))
|
||||||
|
promotiontest(fcat, rand(2)', rand(2)', rand(2)')
|
||||||
|
promotiontest(fcat, rand(2,2), rand(2,2), rand(2,2))
|
||||||
|
promotiontest(fcat, rand(2,2,2), rand(2,2,2), rand(2,2,2))
|
||||||
|
end
|
||||||
|
|
||||||
|
promotiontest(vcat, rand(1,2), rand(2)', rand(2,2))
|
||||||
|
promotiontest(hcat, rand(2,1), rand(2), rand(2,2))
|
||||||
|
promotiontest(vcat, rand(3,4,5), rand(1,4,5), rand(2,4,5))
|
||||||
|
promotiontest(hcat, rand(4,3,5), rand(4,1,5), rand(4,2,5))
|
||||||
|
promotiontest((x...) -> cat(3, x...), rand(4,5,3), rand(4,5,1), rand(4,5,2))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||||
|
|
||||||
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
|
@test gradtest(x -> repmat(x, 5,5), rand(4,5))
|
||||||
@test gradtest(x -> repmat(x, 5), rand(4,5))
|
@test gradtest(x -> repmat(x, 5), rand(4,5))
|
||||||
|
|
||||||
@test gradtest(kron,rand(5), rand(3))
|
@test gradtest(kron, rand(5), rand(3))
|
||||||
@test gradtest(kron, rand(5), rand(3), rand(8))
|
@test gradtest(kron, rand(5), rand(3), rand(8))
|
||||||
@test gradtest(kron,rand(5,1), rand(3,1))
|
@test gradtest(kron, rand(5,1), rand(3,1))
|
||||||
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
|
||||||
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user