vcat back
This commit is contained in:
parent
788d7d35f0
commit
a322c07fc8
@ -29,6 +29,16 @@ Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
|||||||
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
|
||||||
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||||
|
|
||||||
|
Base.vcat(a::TrackedArray, b::TrackedArray) = TrackedArray(Call(vcat, a, b))
|
||||||
|
Base.vcat(a::TrackedArray, b::AbstractArray) = TrackedArray(Call(vcat, a, b))
|
||||||
|
Base.vcat(a::AbstractArray, b::TrackedArray) = TrackedArray(Call(vcat, a, b))
|
||||||
|
|
||||||
|
function back!(::typeof(vcat), Δ, xs, ys)
|
||||||
|
i = Base.tail(map(_ -> :, size(Δ)))
|
||||||
|
@back!(xs, Δ[1:size(xs,1), i...])
|
||||||
|
@back!(ys, Δ[size(xs,1)+1:end, i...])
|
||||||
|
end
|
||||||
|
|
||||||
# Reductions
|
# Reductions
|
||||||
|
|
||||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||||
|
@ -19,4 +19,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
|
|
||||||
@test gradtest(x -> x', rand(5))
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
|
@test gradtest(vcat, rand(5), rand(3))
|
||||||
|
@test gradtest(vcat, rand(2,3), rand(3,3))
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user