vcat back

This commit is contained in:
Mike J Innes 2017-09-05 02:11:28 -04:00
parent 788d7d35f0
commit a322c07fc8
2 changed files with 13 additions and 0 deletions

View File

@ -29,6 +29,16 @@ Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b)) Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedArray, b::TrackedArray) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::TrackedArray, b::AbstractArray) = TrackedArray(Call(vcat, a, b))
Base.vcat(a::AbstractArray, b::TrackedArray) = TrackedArray(Call(vcat, a, b))
function back!(::typeof(vcat), Δ, xs, ys)
i = Base.tail(map(_ -> :, size(Δ)))
@back!(xs, Δ[1:size(xs,1), i...])
@back!(ys, Δ[size(xs,1)+1:end, i...])
end
# Reductions # Reductions
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))

View File

@ -19,4 +19,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> x', rand(5)) @test gradtest(x -> x', rand(5))
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(2,3), rand(3,3))
end end