From 36295799ee3fb4f7c353fbd77f5f90db985aa834 Mon Sep 17 00:00:00 2001 From: Elliot Saba Date: Tue, 27 Feb 2018 18:19:58 -0800 Subject: [PATCH] Add `permutedims()` for tracked arrays --- src/tracker/array.jl | 2 ++ test/tracker.jl | 1 + 2 files changed, 3 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index fd69b803..d5515905 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -114,6 +114,8 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) = back(::typeof(reshape), Δ, xs::TrackedArray, _...) = back(xs, reshape(Δ, size(xs))) +Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims) +back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims))) function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) m1, n1 = size(mat1) diff --git a/test/tracker.jl b/test/tracker.jl index 9aa80b3c..cebc59ec 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -30,6 +30,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3)) @test gradtest(vcat, rand(5), rand(3), rand(8)) @test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2)) +@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(kron,rand(5), rand(3)) @test gradtest(kron, rand(5), rand(3), rand(8))