From 9f9803eec678504c20346d149aaa0cc44461ada0 Mon Sep 17 00:00:00 2001 From: Eric Davies Date: Fri, 26 Oct 2018 13:39:49 -0500 Subject: [PATCH 1/3] Add new-style diagm to tracker --- src/tracker/array.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index c75b5c1c..4ce0a730 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -309,6 +309,9 @@ end # BLAS +LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x) +@grad diagm(x::Pair) = diagm(x[1] => data(x[2])), Δ -> (diag(Δ, x[1]),) + LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x) @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) From 77178b7d674ba34884c0542eabf0bd4c4ee0476e Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 30 Oct 2018 14:21:22 +0000 Subject: [PATCH 2/3] remove old-style definition and test --- src/tracker/array.jl | 3 --- test/tracker.jl | 4 ++-- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 4ce0a730..3512d2d7 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -312,9 +312,6 @@ end LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x) @grad diagm(x::Pair) = diagm(x[1] => data(x[2])), Δ -> (diag(Δ, x[1]),) -LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x) -@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) - x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) x::AbstractMatrix * y::TrackedMatrix = track(*, x, y) x::TrackedMatrix * y::TrackedMatrix = track(*, x, y) diff --git a/test/tracker.jl b/test/tracker.jl index 1f5f6240..ea932815 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -3,7 +3,7 @@ using Flux.Tracker, Test, NNlib using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint using NNlib: conv, depthwiseconv using Printf: @sprintf -using LinearAlgebra: Diagonal, dot, LowerTriangular, norm +using LinearAlgebra: diagm, dot, LowerTriangular, norm using Statistics: mean, std using Random # using StatsBase @@ -127,7 +127,7 @@ end @test gradtest(kron, rand(5,1), rand(3,1), rand(8,1)) @test gradtest(kron, rand(5,2), rand(3,2), rand(8,2)) -@test gradtest(f-> Matrix(Diagonal(f)), rand(3)) +@test gradtest(x -> diagm(0 => x), rand(3)) @test gradtest(W -> inv(log.(W * W)), (5,5)) @test gradtest((A, B) -> A / B , (1,5), (5,5)) From 5df48fbc5d3463dc6f5819fc002449fd2ab01efd Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 5 Nov 2018 11:49:38 +0000 Subject: [PATCH 3/3] fix --- src/tracker/array.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 3512d2d7..14e5136f 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -309,8 +309,8 @@ end # BLAS -LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x) -@grad diagm(x::Pair) = diagm(x[1] => data(x[2])), Δ -> (diag(Δ, x[1]),) +LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...) +@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i)) x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)