Merge pull request #123 from GenaBitu/cat-fix
Added vcat for multiple TrackedVectors
This commit is contained in:
commit
2fec75005d
|
@ -27,21 +27,27 @@ Base.repmat(x::TrackedVecOrMat, a::Integer...) = TrackedArray(Call(repmat, x, a.
|
|||
Base.repmat(x::TrackedVecOrMat, a::Int64...) = TrackedArray(Call(repmat, x, a...))
|
||||
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedVector, b::TrackedVector...) = TrackedArray(Call(vcat, a, b...))
|
||||
Base.vcat(a::TrackedVector, b::AbstractVector) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::AbstractVector, b::TrackedVector) = TrackedArray(Call(vcat, a, b))
|
||||
|
||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedVecOrMat, b::TrackedVecOrMat...) = TrackedArray(Call(vcat, a, b...))
|
||||
Base.vcat(a::TrackedVecOrMat, b::AbstractVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::AbstractVecOrMat, b::TrackedVecOrMat) = TrackedArray(Call(vcat, a, b))
|
||||
|
||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::TrackedMatrix, b::TrackedMatrix...) = TrackedArray(Call(vcat, a, b...))
|
||||
Base.vcat(a::TrackedMatrix, b::AbstractMatrix) = TrackedArray(Call(vcat, a, b))
|
||||
Base.vcat(a::AbstractMatrix, b::TrackedMatrix) = TrackedArray(Call(vcat, a, b))
|
||||
|
||||
function back(::typeof(vcat), Δ, xs, ys)
|
||||
function back(::typeof(vcat), Δ, xs...)
|
||||
i = Base.tail(map(_ -> :, size(Δ)))
|
||||
@back(xs, Δ[1:size(xs,1), i...])
|
||||
@back(ys, Δ[size(xs,1)+1:end, i...])
|
||||
start = 0
|
||||
for xsi in xs
|
||||
@back(xsi, Δ[start+1:start+size(xsi,1), i...])
|
||||
start += size(xsi, 1)
|
||||
end
|
||||
end
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
|
||||
|
|
|
@ -2,7 +2,7 @@ using Flux.Tracker, Base.Test, NNlib
|
|||
using Flux.Tracker: gradcheck
|
||||
using NNlib
|
||||
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(f(xs...)), xs...)
|
||||
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
|
||||
gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
|
||||
@testset "Tracker" begin
|
||||
|
@ -28,7 +28,8 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(x -> x', rand(5))
|
||||
|
||||
@test gradtest(vcat, rand(5), rand(3))
|
||||
@test gradtest(vcat, rand(2,3), rand(3,3))
|
||||
@test gradtest(vcat, rand(5), rand(3), rand(8))
|
||||
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
|
||||
|
||||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
|
Loading…
Reference in New Issue