diff --git a/src/tracker/array.jl b/src/tracker/array.jl index c75b5c1c..4ce0a730 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -309,6 +309,9 @@ end # BLAS +LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x) +@grad diagm(x::Pair) = diagm(x[1] => data(x[2])), Δ -> (diag(Δ, x[1]),) + LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x) @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)