cat promotions and mixed ranks
This commit is contained in:
parent
eaaf5fd34c
commit
509a2e59f6
@ -81,17 +81,18 @@ back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ'))
|
|||||||
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
|
||||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
|
||||||
|
|
||||||
Base.vcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(vcat, a, b...)
|
for f in [:vcat, :hcat]
|
||||||
Base.vcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(vcat, a, b)
|
@eval begin
|
||||||
Base.vcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(vcat, a, b)
|
Base.$f(a::TrackedArray...) = track($f, a...)
|
||||||
|
Base.$f(a::TrackedArray, b::Array...) = track($f, a, b...)
|
||||||
|
|
||||||
Base.hcat(a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(hcat, a, b...)
|
# assumes there is another function to capture Union{Matrix,Vector}... without any TrackedMatrix or TrackedVector
|
||||||
Base.hcat(a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(hcat, a, b)
|
Base.$f(a::Union{TrackedMatrix,TrackedVector,Matrix,Vector}...) = track($f, a...)
|
||||||
Base.hcat(a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(hcat, a, b)
|
end
|
||||||
|
end
|
||||||
|
|
||||||
Base.cat(dim::Int, a::A, b::B...) where {A <: TrackedArray, B <: TrackedArray} = track(cat, dim, a, b...)
|
Base.cat(dim::Int, a::TrackedArray...) = track(Base.cat, dim, a...)
|
||||||
Base.cat(dim::Int, a::A, b::B) where {A <: TrackedArray, B <: AbstractArray} = track(cat, dim, a, b)
|
Base.cat(dim::Int, a::TrackedArray, b::Array...) = track(Base.cat, dim, a, b...)
|
||||||
Base.cat(dim::Int, a::A, b::B) where {A <: AbstractArray, B <: TrackedArray} = track(cat, dim, a, b)
|
|
||||||
|
|
||||||
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
||||||
Δ′ = similar(xs.data)
|
Δ′ = similar(xs.data)
|
||||||
@ -106,21 +107,21 @@ function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function back(::typeof(vcat), Δ, xs...)
|
function back(::typeof(vcat), Δ, xs...)
|
||||||
i = fill(:, ndims(Δ)-1)
|
|
||||||
start = 0
|
start = 0
|
||||||
for xsi in xs
|
for xsi in xs
|
||||||
|
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||||
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
||||||
start += size(xsi, 1)
|
start += size(xsi, 1)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function back(::typeof(hcat), Δ, xs...)
|
function back(::typeof(hcat), Δ, xs...)
|
||||||
i = fill(:, ndims(Δ)-2)
|
|
||||||
start = 0
|
start = 0
|
||||||
for xsi in xs
|
for xsi in xs
|
||||||
if ndims(xsi) == 1
|
if ndims(xsi) == 1
|
||||||
@back(xsi, Δ[:, start+1])
|
@back(xsi, Δ[:, start+1])
|
||||||
else
|
else
|
||||||
|
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
|
||||||
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
|
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...])
|
||||||
end
|
end
|
||||||
start += size(xsi, 2)
|
start += size(xsi, 2)
|
||||||
@ -128,14 +129,15 @@ function back(::typeof(hcat), Δ, xs...)
|
|||||||
end
|
end
|
||||||
|
|
||||||
function back(::typeof(cat), Δ, dim, xs...)
|
function back(::typeof(cat), Δ, dim, xs...)
|
||||||
i = fill(:, dim-1)
|
|
||||||
j = fill(:, ndims(Δ)-dim)
|
|
||||||
start = 0
|
start = 0
|
||||||
for xsi in xs
|
for xsi in xs
|
||||||
if ndims(xsi) < dim
|
if ndims(xsi) < dim
|
||||||
a = [fill(:, ndims(xsi)); ones(Int, dim-ndims(xsi)-1)]
|
i = map(_ -> :, size(xsi))
|
||||||
@back(xsi, Δ[a..., start+1])
|
j = ones(Int, dim-ndims(xsi)-1)
|
||||||
|
@back(xsi, Δ[i..., j..., start+1])
|
||||||
else
|
else
|
||||||
|
i = fill(:, dim-1)
|
||||||
|
j = fill(:, ndims(xsi)-dim)
|
||||||
@back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...])
|
@back(xsi, Δ[i..., start+1:start+size(xsi,dim), j...])
|
||||||
end
|
end
|
||||||
start += size(xsi, dim)
|
start += size(xsi, dim)
|
||||||
|
@ -29,19 +29,42 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
|
|
||||||
@test gradtest(x -> x', rand(5))
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
|
function simplepromotioncheck(f, A, B)
|
||||||
|
r0 = f(A, B)
|
||||||
|
r1 = f(param(A), B)
|
||||||
|
r2 = f(A, param(B))
|
||||||
|
r3 = f(param(A), param(B))
|
||||||
|
|
||||||
|
r1 == r2 == r3 && r0 == Flux.data(r1)
|
||||||
|
end
|
||||||
|
|
||||||
@testset "concat" begin
|
@testset "concat" begin
|
||||||
@testset "vcat $i" for (i,vcatf) in enumerate((vcat, (x...) -> cat(1, x...)))
|
cat1(x...) = cat(1, x...)
|
||||||
|
cat2(x...) = cat(2, x...)
|
||||||
|
|
||||||
|
@testset for vcatf in [vcat, cat1]
|
||||||
@test gradtest(vcatf, rand(5), rand(3))
|
@test gradtest(vcatf, rand(5), rand(3))
|
||||||
@test gradtest(vcatf, rand(5), rand(3), rand(8))
|
@test gradtest(vcatf, rand(5), rand(3), rand(8))
|
||||||
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
|
@test gradtest(vcatf, rand(5,2), rand(3,2), rand(8,2))
|
||||||
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
|
@test gradtest(vcatf, rand(5,2,3), rand(3,2,3), rand(8,2,3))
|
||||||
|
@test gradtest(vcatf, rand(5), rand(3,1))
|
||||||
|
@test gradtest(vcatf, rand(5)', rand(2,5))
|
||||||
end
|
end
|
||||||
@testset "hcat $i" for (i,hcatf) in enumerate((hcat, (x...) -> cat(2, x...)))
|
|
||||||
|
@test simplepromotioncheck(vcat, rand(5), rand(5))
|
||||||
|
|
||||||
|
@testset for hcatf in [hcat, cat2]
|
||||||
@test gradtest(hcatf, rand(5), rand(5))
|
@test gradtest(hcatf, rand(5), rand(5))
|
||||||
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
|
@test gradtest(hcatf, rand(2,5), rand(2,3), rand(2,8))
|
||||||
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
|
@test gradtest(hcatf, rand(2,5,3), rand(2,3,3), rand(2,8,3))
|
||||||
|
@test gradtest(hcatf, rand(5)', rand(1,3))
|
||||||
|
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@test simplepromotioncheck(hcat, rand(5), rand(5))
|
||||||
|
|
||||||
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||||
|
|
||||||
@testset "cat($dim, ...)" for dim in 1:5
|
@testset "cat($dim, ...)" for dim in 1:5
|
||||||
catdim = (x...) -> cat(dim, x...)
|
catdim = (x...) -> cat(dim, x...)
|
||||||
@test gradtest(catdim, rand(5), rand(5))
|
@test gradtest(catdim, rand(5), rand(5))
|
||||||
|
Loading…
Reference in New Issue
Block a user