This commit is contained in:
Mike J Innes 2018-07-12 20:42:32 +01:00
parent adc216f182
commit 00cfe24d66
2 changed files with 19 additions and 15 deletions

View File

@ -103,6 +103,7 @@ Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
end
for f in [:vcat, :hcat]
UArray = :(Union{TrackedArray,Vector,Matrix,Adjoint,Transpose})
@eval begin
# This section is a bit of a hack since julia doesn't have a standardised
# promotion mechanism for concatenation yet
@ -111,18 +112,18 @@ for f in [:vcat, :hcat]
# It should support tracked concatenation with rank ∈ (1,2) with a
# TrackedArray anywhere among the arguments This works as long as base has
# other functions that captures `(::Union{Vector,RowVector,Matrix}...)`.
Base.$f(a::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a...)
Base.$f(a::$UArray...) = track($f, a...)
# It should support tracked concatenation with rank>2 if the TrackedArray is
# first
Base.$f(a::TrackedArray, b::AbstractArray...) = track($f, a, b...)
Base.$f(a::TrackedArray, b::Union{TrackedArray,Vector,RowVector,Matrix}...) = track($f, a, b...) # resolves ambiguity introduced by previous row
Base.$f(a::TrackedArray, b::$UArray...) = track($f, a, b...) # resolves ambiguity introduced by previous row
# It should support tracked concatenation with rank>2 if the TrackedArray is
# second
Base.$f(a::Array, b::TrackedArray, c::AbstractArray...) = track($f, a, b, c...)
Base.$f(a::Union{Vector,RowVector,Matrix}, b::TrackedArray,
c::Union{TrackedArray,Vector,RowVector,Matrix}...) =
Base.$f(a::Union{Vector,Matrix,Adjoint,Transpose}, b::TrackedArray,
c::$UArray...) =
track($f, a, b, c...) # resolves ambiguity introduced by previous row
end
end
@ -157,11 +158,13 @@ end
end
end
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
Base.cat(a::TrackedArray; dims) = track_kw(cat, a, dims = dims)
Base.cat(a::TrackedArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
Base.cat(a::TrackedArray, b::AbstractArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
Base.cat(a::AbstractArray, b::TrackedArray, c::AbstractArray...; dims) = track_kw(cat, a, b, c..., dims = dims)
@grad function cat(dims, Xs...)
cat(dims, data.(Xs)...), function (Δ)
@grad function cat(Xs...; dims)
cat(data.(Xs)..., dims = dims), function (Δ)
start = ntuple(i -> 0, Val{ndims(Δ)})
Δs = [begin
dim_xs = 1:ndims(xs)
@ -171,7 +174,7 @@ Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...)
start = start .+ till_xs
d
end for xs in Xs]
return (nothing, Δs...,)
return (Δs...,)
end
end

View File

@ -1,3 +1,4 @@
using Flux
using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
@ -49,8 +50,8 @@ function promotiontest(f, A, B, C)
end
@testset "concat" begin
cat1(x...) = cat(1, x...)
cat2(x...) = cat(2, x...)
cat1(x...) = cat(x..., dims = 1)
cat2(x...) = cat(x..., dims = 2)
@testset for vcatf in [vcat, cat1]
@test gradtest(vcatf, rand(5), rand(3))
@ -72,17 +73,17 @@ end
@test gradtest(hcatf, rand(5), rand(5,2))
end
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]
@testset for catf in [vcat, cat1, hcat, cat2, (x...) -> cat(x..., dims = 3), (x...) -> cat(x..., dims = (1,2))]
@test gradtest(catf, rand(5))
@test gradtest(catf, rand(5)')
@test gradtest(catf, rand(2,5))
@test gradtest(catf, rand(2,5,3))
end
@test gradtest((x...) -> cat(3, x...), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@test gradtest((x...) -> cat(x..., dims = 3), rand(2,5,2), rand(2,5,3), rand(2,5,4))
@testset "cat($dim, ...)" for dim in 3:5
catdim = (x...) -> cat(dim, x...)
catdim = (x...) -> cat(x..., dims = dim)
@test gradtest(catdim, rand(5), rand(5), rand(5))
@test gradtest(catdim, rand(2,5), rand(2,5), rand(2,5))
@test gradtest(catdim, rand(2,5,3), rand(2,5,3), rand(2,5,3))
@ -92,7 +93,7 @@ end
@test !isa(hcat(rand(2)), TrackedArray)
@test !isa(cat(1,rand(2)), TrackedArray)
@test gradtest((a,b)->cat((2,3,5), a, b), rand(2,3), rand(2,4,2,1))
@test gradtest((a,b)->cat(a, b, dims = (2,3,5)), rand(2,3), rand(2,4,2,1))
@testset "promotiontest" begin
@testset for fcat in [hcat, vcat, (x...) -> cat(3, x...), (x...) -> cat((1,2), x...)]