Add `permutedims()` for tracked arrays

This commit is contained in:
Elliot Saba 2018-02-27 18:19:58 -08:00
parent bdd8162bf8
commit 36295799ee
2 changed files with 3 additions and 0 deletions

View File

@ -114,6 +114,8 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims)))
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)

View File

@ -30,6 +30,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(5), rand(3), rand(8))
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(kron,rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))