Use other definition for grad(det(A)).
This commit is contained in:
parent
aa64d2157d
commit
f790fff59a
@ -125,7 +125,7 @@ Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
|||||||
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
@grad adjoint(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
|
||||||
|
|
||||||
det(xs::TrackedArray) = track(det, xs)
|
det(xs::TrackedArray) = track(det, xs)
|
||||||
@grad det(xs) = det(data(xs)), Δ -> (Δ * transpose(adjoint(xs)),)
|
@grad det(xs) = det(data(xs)), Δ -> (Δ * det(xs) * transpose(inv(xs)),)
|
||||||
|
|
||||||
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
Base.repeat(xs::TrackedArray; kw...) = track(repeat, xs; kw...)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user