reshape with tupled dimensions and kronecker product

This commit is contained in:
jessebett 2018-02-08 14:27:57 -05:00
parent 4511936a87
commit f84ee8eab0

View File

@ -106,9 +106,23 @@ end
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
track(reshape, xs, dims...) track(reshape, xs, dims...)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
track(reshape, xs, dims)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) = back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
back(xs, reshape(Δ, size(xs))) back(xs, reshape(Δ, size(xs)))
function Base.kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1)
mat1_rsh = reshape(mat1,(1,m1,1,n1))
m2, n2 = size(mat2)
mat2_rsh = reshape(mat2,(m2,1,n2,1))
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
end
# Reductions # Reductions
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim) Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)