diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 93ec7bce..4d547d5a 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -106,9 +106,23 @@ end Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = track(reshape, xs, dims...) +Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64,N}} where N) = + track(reshape, xs, dims) + back(::typeof(reshape), Δ, xs::TrackedArray, _...) = back(xs, reshape(Δ, size(xs))) + +function Base.kron(mat1::AbstractMatrix,mat2::AbstractMatrix) + m1, n1 = size(mat1) + mat1_rsh = reshape(mat1,(1,m1,1,n1)) + + m2, n2 = size(mat2) + mat2_rsh = reshape(mat2,(m2,1,n2,1)) + + return reshape(mat1_rsh.*mat2_rsh, (m1*m2,n1*n2)) +end + # Reductions Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)