From a322c07fc81dcd86f2f6f2318aac94e4d4107bff Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 5 Sep 2017 02:11:28 -0400 Subject: [PATCH] vcat back --- src/tracker/lib.jl | 10 ++++++++++ test/tracker.jl | 3 +++ 2 files changed, 13 insertions(+) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 8031db0f..03046e87 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -29,6 +29,16 @@ Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::TrackedArray, b::TrackedArray) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::TrackedArray, b::AbstractArray) = TrackedArray(Call(vcat, a, b)) +Base.vcat(a::AbstractArray, b::TrackedArray) = TrackedArray(Call(vcat, a, b)) + +function back!(::typeof(vcat), Δ, xs, ys) + i = Base.tail(map(_ -> :, size(Δ))) + @back!(xs, Δ[1:size(xs,1), i...]) + @back!(ys, Δ[size(xs,1)+1:end, i...]) +end + # Reductions Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) diff --git a/test/tracker.jl b/test/tracker.jl index 258e1af4..0fa4598b 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -19,4 +19,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> x', rand(5)) +@test gradtest(vcat, rand(5), rand(3)) +@test gradtest(vcat, rand(2,3), rand(3,3)) + end