diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index 432244ce..3f607805 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -320,8 +320,8 @@ end # BLAS -LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x) -@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) +LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...) +@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i)) x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) x::AbstractMatrix * y::TrackedMatrix = track(*, x, y) diff --git a/test/tracker.jl b/test/tracker.jl index b12e3bb6..93f6c6ce 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -3,7 +3,7 @@ using Flux.Tracker, Test, NNlib using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint using NNlib: conv, depthwiseconv using Printf: @sprintf -using LinearAlgebra: Diagonal, dot, LowerTriangular, norm +using LinearAlgebra: diagm, dot, LowerTriangular, norm using Statistics: mean, std using Random # using StatsBase @@ -132,7 +132,7 @@ end @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) -@test gradtest(f-> Matrix(Diagonal(f)), rand(3)) +@test gradtest(x -> diagm(0 => x), rand(3)) @test gradtest(W -> inv(log.(W * W)), (5,5)) @test gradtest((A, B) -> A / B , (1,5), (5,5))