add inv/ldivide/rdivide + test
This commit is contained in:
parent
b3a08baf55
commit
d131853587
|
@ -1,6 +1,8 @@
|
|||
import Base: *
|
||||
|
||||
import LinearAlgebra
|
||||
import LinearAlgebra: inv, \, /
|
||||
|
||||
using Statistics
|
||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||
|
||||
|
@ -205,6 +207,41 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
||||
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||
|
||||
|
||||
inv(A::TrackedArray) = Tracker.track(inv, A)
|
||||
@grad function inv(A)
|
||||
return inv(Tracker.data(A)), function (Δ)
|
||||
Ainv = inv(A)
|
||||
∇A = - Ainv' * Δ * Ainv'
|
||||
return (∇A, )
|
||||
end
|
||||
end
|
||||
|
||||
# (/) rdivide
|
||||
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
|
||||
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
|
||||
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
|
||||
@grad function (A / B)
|
||||
return Tracker.data(A) / Tracker.data(B), function (Δ)
|
||||
Binv = inv(B)
|
||||
∇B = - Binv' * A' * Δ * Binv'
|
||||
return (Δ * Binv', ∇B)
|
||||
end
|
||||
end
|
||||
|
||||
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
|
||||
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
|
||||
@grad function (A \ B)
|
||||
return Tracker.data(A) \ Tracker.data(B), function (Δ)
|
||||
Ainv = inv(A)
|
||||
∇A = - Ainv' * Δ * B' * Ainv'
|
||||
return (∇A, Ainv' * Δ)
|
||||
end
|
||||
end
|
||||
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||
|
|
|
@ -129,6 +129,11 @@ end
|
|||
|
||||
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
|
||||
|
||||
@test gradtest(W -> inv(log.(W * W)), (5,5))
|
||||
@test gradtest((A, B) -> A / B , (1,5), (5,5))
|
||||
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
|
||||
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
|
||||
|
||||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
||||
|
|
Loading…
Reference in New Issue