Add `permutedims()` for tracked arrays
This commit is contained in:
parent
bdd8162bf8
commit
36295799ee
|
@ -114,6 +114,8 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
|
|||
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
||||
back(xs, reshape(Δ, size(xs)))
|
||||
|
||||
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
|
||||
back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims)))
|
||||
|
||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
|
|
|
@ -30,6 +30,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(vcat, rand(5), rand(3))
|
||||
@test gradtest(vcat, rand(5), rand(3), rand(8))
|
||||
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
|
||||
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
|
||||
|
||||
@test gradtest(kron,rand(5), rand(3))
|
||||
@test gradtest(kron, rand(5), rand(3), rand(8))
|
||||
|
|
Loading…
Reference in New Issue