Merge pull request #467 from invenia/ed/diagm-pair

Add new-style diagm to tracker
This commit is contained in:
Mike J Innes 2018-11-05 12:23:16 +00:00
parent 96dbae2d20
commit 63a3acbf5e
2 changed files with 4 additions and 4 deletions

View File

@ -309,8 +309,8 @@ end
# BLAS
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
LinearAlgebra.diagm(x::Pair{<:Integer, <:TrackedVector}) = track(diagm, x...)
@grad diagm(i, x) = diagm(i => data(x)), Δ -> (nothing, diag(Δ, i))
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)

View File

@ -3,7 +3,7 @@ using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv
using Printf: @sprintf
using LinearAlgebra: Diagonal, dot, LowerTriangular, norm
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
using Random
# using StatsBase
@ -127,7 +127,7 @@ end
@test gradtest(kron, rand(5,1), rand(3,1), rand(8,1))
@test gradtest(kron, rand(5,2), rand(3,2), rand(8,2))
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
@test gradtest(x -> diagm(0 => x), rand(3))
@test gradtest(W -> inv(log.(W * W)), (5,5))
@test gradtest((A, B) -> A / B , (1,5), (5,5))