PermutedDimsArray like permutedims
e.g. PermutedDimsArray(rand(2,3) |> param, (2,1))
This commit is contained in:
parent
8386a49bf9
commit
031d1b3d57
@ -223,8 +223,11 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs,
|
||||
|
||||
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
|
||||
|
||||
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
|
||||
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
|
||||
Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
|
||||
@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
|
||||
|
||||
Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
|
||||
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)
|
||||
|
||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
|
Loading…
Reference in New Issue
Block a user