Merge pull request #169 from jessebett/jessechanges

Reshape with Tuple Dimensions and Kronecker Product
This commit is contained in:
Mike J Innes 2018-02-16 14:16:42 +00:00 committed by GitHub
commit e3b31b9b87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 0 deletions

View File

@ -108,9 +108,27 @@ end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
track(reshape, xs, dims...)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
track(reshape, xs, dims)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs)))
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
mat1_rsh = reshape(mat1,(1,m1,1,n1))
m2, n2 = size(mat2)
mat2_rsh = reshape(mat2,(m2,1,n2,1))
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
end
Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
# Reductions
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)

View File

@ -31,6 +31,12 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5), rand(3), rand(8))
@test gradtest(vcat, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(kron,rand(5), rand(3))
@test gradtest(kron, rand(5), rand(3), rand(8))
@test gradtest(kron,rand(5,1), rand(3,1))
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(diagm, rand(3))
@testset "mean" begin