commit
0469394715
|
@ -223,8 +223,11 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs,
|
|||
|
||||
@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
|
||||
|
||||
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
|
||||
@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
|
||||
Base.permutedims(xs::TrackedArray, perm) = track(permutedims, xs, perm)
|
||||
@grad permutedims(xs, perm) = permutedims(data(xs), perm), Δ -> (permutedims(Δ, invperm(perm)),nothing)
|
||||
|
||||
Base.PermutedDimsArray(xs::TrackedArray, perm) = track(PermutedDimsArray, xs, perm)
|
||||
@grad PermutedDimsArray(xs, perm) = PermutedDimsArray(data(xs), perm), Δ -> (PermutedDimsArray(Δ, invperm(perm)),nothing)
|
||||
|
||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
|
|
|
@ -116,6 +116,7 @@ end
|
|||
end
|
||||
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
@test gradtest(x -> PermutedDimsArray(x, [3,1,2]), rand(4,5,6))
|
||||
|
||||
@test gradtest(x -> repeat(x; inner=2), rand(5))
|
||||
@test gradtest(x -> repeat(x; inner=2, outer=3), rand(5))
|
||||
|
|
Loading…
Reference in New Issue