fix cat
This commit is contained in:
parent
adc216f182
commit
00cfe24d66
|
@ -103,6 +103,7 @@ Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
|||
end
|
||||
|
||||
for f in [:vcat, :hcat]
|
||||
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
|
||||
@eval begin
|
||||
# This section is a bit of a hack since julia doesn't have a standardised
|
||||
# promotion mechanism for concatenation yet
|
||||
|
@ -111,18 +112,18 @@ for f in [:vcat, :hcat]
|
|||
# It should support tracked concatenation with rank ∈ (1,2) with a
|
||||
# TrackedArray anywhere among the arguments This works as long as base has
|
||||
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
|
||||
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
|
||||
Base.$f(a::$UArray...) = track($f, a...)
|
||||
|
||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||
# first
|
||||
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
|
||||
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
||||
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row
|
||||
|
||||
# It should support tracked concatenation with rank>2 if the TrackedArray is
|
||||
# second
|
||||
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
|
||||
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
|
||||
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
|
||||
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
|
||||
c::$UArray...) =
|
||||
track($f, a, b, c...) # resolves ambiguity introduced by previous row
|
||||
end
|
||||
end
|
||||
|
@ -157,11 +158,13 @@ end
|
|||
end
|
||||
end
|
||||
|
||||
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
|
||||
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
|
||||
Base.cat(a::TrackedArray; dims) = track_kw(cat, a, dims = dims)
|
||||
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
|
||||
|
||||
@grad function cat(dims, Xs...)
|
||||
cat(dims, data.(Xs)...), function (Δ)
|
||||
@grad function cat(Xs...; dims)
|
||||
cat(data.(Xs)..., dims = dims), function (Δ)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
Δs = [begin
|
||||
dim_xs = 1:ndims(xs)
|
||||
|
@ -171,7 +174,7 @@ Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...)
|
|||
start = start .+ till_xs
|
||||
d
|
||||
end for xs in Xs]
|
||||
return (nothing, Δs...,)
|
||||
return (Δs...,)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
using Flux
|
||||
using Flux.Tracker, Base.Test, NNlib
|
||||
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
|
||||
using NNlib: conv
|
||||
|
@ -49,8 +50,8 @@ function promotiontest(f, A, B, C)
|
|||
end
|
||||
|
||||
@testset "concat" begin
|
||||
cat1(x...) = cat(1, x...)
|
||||
cat2(x...) = cat(2, x...)
|
||||
cat1(x...) = cat(x..., dims = 1)
|
||||
cat2(x...) = cat(x..., dims = 2)
|
||||
|
||||
@testset for vcatf in [vcat, cat1]
|
||||
@test gradtest(vcatf, rand(5), rand(3))
|
||||
|
@ -72,17 +73,17 @@ end
|
|||
@test gradtest(hcatf, rand(5), rand(5,2))
|
||||
end
|
||||
|
||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
|
||||
@test gradtest(catf, rand(5))
|
||||
@test gradtest(catf, rand(5)')
|
||||
@test gradtest(catf, rand(2,5))
|
||||
@test gradtest(catf, rand(2,5,3))
|
||||
end
|
||||
|
||||
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
|
||||
|
||||
@testset "cat($dim, ...)" for dim in 3:5
|
||||
catdim = (x...) -> cat(dim, x...)
|
||||
catdim = (x...) -> cat(x..., dims = dim)
|
||||
@test gradtest(catdim, rand(5), rand(5), rand(5))
|
||||
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
|
||||
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
|
||||
|
@ -92,7 +93,7 @@ end
|
|||
@test !isa(hcat(rand(2)), TrackedArray)
|
||||
@test !isa(cat(1,rand(2)), TrackedArray)
|
||||
|
||||
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
|
||||
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
|
||||
|
||||
@testset "promotiontest" begin
|
||||
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
|
||||
|
|
Loading…
Reference in New Issue