Added tracking of `logdet` and `logabsdet`. Added gradtests.

This commit is contained in:
Ivan Yashchuk 2019-02-08 09:55:33 +02:00
parent f790fff59a
commit e00ac88016
2 changed files with 12 additions and 2 deletions

View File

@ -1,7 +1,7 @@
import Base: *
import LinearAlgebra
import LinearAlgebra: inv, det, \, /
import LinearAlgebra: inv, det, logdet, logabsdet, \, /
using Statistics
using LinearAlgebra: Transpose, Adjoint, diagm, diag
@ -127,6 +127,12 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
det(xs::TrackedArray) = track(det, xs)
@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),)
logdet(xs::TrackedArray) = track(logdet, xs)
@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),)
logabsdet(xs::TrackedArray) = track(logabsdet, xs)
@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))

View File

@ -3,7 +3,7 @@ using Flux.Tracker, Test, NNlib
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
using Statistics: mean, std
using Random
# using StatsBase
@ -34,6 +34,10 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@test gradtest(x -> x', rand(5))
@test gradtest(det, (4, 4))
@test gradtest(logdet, (4, 4))
@test gradtest((x) -> logabsdet(x)[1], (4, 4))
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
end