diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 3512d2d7..14e5136f 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -309,8 +309,8 @@ end # BLAS -LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x) -@grad diagm(x::Pair) = diagm(x[1] => data(x[2])), Δ -> (diag(Δ, x[1]),) +LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...) +@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i)) x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)