Fixed issue #542.
Added tracking of LinearAlgebra.det and its grad method.
This commit is contained in:
parent
940b1e6dbf
commit
aa64d2157d
@ -1,7 +1,7 @@
|
|||||||
import Base: *
|
import Base: *
|
||||||
|
|
||||||
import LinearAlgebra
|
import LinearAlgebra
|
||||||
import LinearAlgebra: inv, \, /
|
import LinearAlgebra: inv, det, \, /
|
||||||
|
|
||||||
using Statistics
|
using Statistics
|
||||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||||
@ -124,6 +124,9 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
|||||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
|
@grad transpose(xs) = transpose(data(xs)), Δ -> (trim(xs, transpose(Δ)),)
|
||||||
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
||||||
|
|
||||||
|
det(xs::TrackedArray) = track(det, xs)
|
||||||
|
@grad det(xs) = det(data(xs)), Δ -> (Δ * transpose(adjoint(xs)),)
|
||||||
|
|
||||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||||
|
|
||||||
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
|
||||||
|
Loading…
Reference in New Issue
Block a user