reshape with tupled dimensions and kronecker product
This commit is contained in:
parent
4511936a87
commit
f84ee8eab0
@ -106,9 +106,23 @@ end
|
||||
Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) =
|
||||
track(reshape, xs, dims...)
|
||||
|
||||
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) =
|
||||
track(reshape, xs, dims)
|
||||
|
||||
back(::typeof(reshape), Δ, xs::TrackedArray, _...) =
|
||||
back(xs, reshape(Δ, size(xs)))
|
||||
|
||||
|
||||
function Base.kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
mat1_rsh = reshape(mat1,(1,m1,1,n1))
|
||||
|
||||
m2, n2 = size(mat2)
|
||||
mat2_rsh = reshape(mat2,(m2,1,n2,1))
|
||||
|
||||
return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2))
|
||||
end
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
||||
|
Loading…
Reference in New Issue
Block a user