Merge pull request #123 from GenaBitu/cat-fix
Added vcat for multiple TrackedVectors
This commit is contained in:
commit
2fec75005d
@ -27,21 +27,27 @@ Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a.
|
|||||||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
|
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
|
||||||
|
|
||||||
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||||
|
Base.vcat(a::TrackedVector, b::TrackedVector...) = TrackedArray(Call(vcat, a, b...))
|
||||||
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
|
||||||
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||||
|
|
||||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||||
|
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = TrackedArray(Call(vcat, a, b...))
|
||||||
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||||
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||||
|
|
||||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
||||||
|
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = TrackedArray(Call(vcat, a, b...))
|
||||||
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
|
||||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
||||||
|
|
||||||
function back(::typeof(vcat), Δ, xs, ys)
|
function back(::typeof(vcat), Δ, xs...)
|
||||||
i = Base.tail(map(_ -> :, size(Δ)))
|
i = Base.tail(map(_ -> :, size(Δ)))
|
||||||
@back(xs, Δ[1:size(xs,1), i...])
|
start = 0
|
||||||
@back(ys, Δ[size(xs,1)+1:end, i...])
|
for xsi in xs
|
||||||
|
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
||||||
|
start += size(xsi, 1)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
|
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
|
||||||
|
@ -2,7 +2,7 @@ using Flux.Tracker, Base.Test, NNlib
|
|||||||
using Flux.Tracker: gradcheck
|
using Flux.Tracker: gradcheck
|
||||||
using NNlib
|
using NNlib
|
||||||
|
|
||||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||||
|
|
||||||
@testset "Tracker" begin
|
@testset "Tracker" begin
|
||||||
@ -28,7 +28,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
@test gradtest(x -> x', rand(5))
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
@test gradtest(vcat, rand(5), rand(3))
|
@test gradtest(vcat, rand(5), rand(3))
|
||||||
@test gradtest(vcat, rand(2,3), rand(3,3))
|
@test gradtest(vcat, rand(5), rand(3), rand(8))
|
||||||
|
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
|
||||||
|
|
||||||
@testset "mean" begin
|
@testset "mean" begin
|
||||||
@test gradtest(mean, rand(2, 3))
|
@test gradtest(mean, rand(2, 3))
|
||||||
|
Loading…
Reference in New Issue
Block a user