Merge pull request #187 from staticfloat/patch-1

Add `permutedims()` for tracked arrays
This commit is contained in:
Mike J Innes 2018-03-02 18:23:19 +00:00 committed by GitHub
commit eaa9fd2dd3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 3 additions and 0 deletions

View File

@ -114,6 +114,8 @@ Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims)))
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)

View File

@ -30,6 +30,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(5), rand(3), rand(8))
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
@test gradtest(kron,rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))