PermutedDimsArray like permutedims

e.g. PermutedDimsArray(rand(2,3) |> param, (2,1))
This commit is contained in:
Michael Abbott 2019-01-28 18:15:32 +01:00 committed by GitHub
parent 8386a49bf9
commit 031d1b3d57
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -223,8 +223,11 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs,
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)