Merge pull request #596 from IvanYashchuk/ivan/topic-issue-542
Fixed issue #542.
This commit is contained in:
commit
fe712bf338
@ -1,7 +1,7 @@
|
|||||||
import Base: *
|
import Base: *
|
||||||
|
|
||||||
import LinearAlgebra
|
import LinearAlgebra
|
||||||
import LinearAlgebra: inv, \, /
|
import LinearAlgebra: inv, det, logdet, logabsdet, \, /
|
||||||
|
|
||||||
using Statistics
|
using Statistics
|
||||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||||
@ -124,6 +124,15 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
|||||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
|
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
|
||||||
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
||||||
|
|
||||||
|
det(xs::TrackedArray) = track(det, xs)
|
||||||
|
@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),)
|
||||||
|
|
||||||
|
logdet(xs::TrackedArray) = track(logdet, xs)
|
||||||
|
@grad logdet(xs) = logdet(data(xs)), Δ -> (Δ * transpose(inv(xs)),)
|
||||||
|
|
||||||
|
logabsdet(xs::TrackedArray) = track(logabsdet, xs)
|
||||||
|
@grad logabsdet(xs) = logabsdet(data(xs)), Δ -> (Δ[1] * transpose(inv(xs)),)
|
||||||
|
|
||||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||||
|
|
||||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
||||||
|
@ -3,7 +3,7 @@ using Flux.Tracker, Test, NNlib
|
|||||||
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
|
using Flux.Tracker: TrackedReal, gradient, gradcheck, grad, checkpoint, forwarddiff
|
||||||
using NNlib: conv, ∇conv_data, depthwiseconv
|
using NNlib: conv, ∇conv_data, depthwiseconv
|
||||||
using Printf: @sprintf
|
using Printf: @sprintf
|
||||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
using LinearAlgebra: diagm, dot, LowerTriangular, norm, det, logdet, logabsdet
|
||||||
using Statistics: mean, std
|
using Statistics: mean, std
|
||||||
using Random
|
using Random
|
||||||
# using StatsBase
|
# using StatsBase
|
||||||
@ -34,6 +34,10 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
|||||||
|
|
||||||
@test gradtest(x -> x', rand(5))
|
@test gradtest(x -> x', rand(5))
|
||||||
|
|
||||||
|
@test gradtest(det, (4, 4))
|
||||||
|
@test gradtest(logdet, map((x) -> x*x', (rand(4, 4),))[1])
|
||||||
|
@test gradtest((x) -> logabsdet(x)[1], (4, 4))
|
||||||
|
|
||||||
@testset "indexing & slicing" begin
|
@testset "indexing & slicing" begin
|
||||||
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
|
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user