add inv/ldivide/rdivide + test
This commit is contained in:
parent
b3a08baf55
commit
d131853587
@ -1,6 +1,8 @@
|
|||||||
import Base: *
|
import Base: *
|
||||||
|
|
||||||
import LinearAlgebra
|
import LinearAlgebra
|
||||||
|
import LinearAlgebra: inv, \, /
|
||||||
|
|
||||||
using Statistics
|
using Statistics
|
||||||
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
using LinearAlgebra: Transpose, Adjoint, diagm, diag
|
||||||
|
|
||||||
@ -205,6 +207,41 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b)
|
|||||||
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b)
|
||||||
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b)
|
||||||
|
|
||||||
|
|
||||||
|
inv(A::TrackedArray) = Tracker.track(inv, A)
|
||||||
|
@grad function inv(A)
|
||||||
|
return inv(Tracker.data(A)), function (Δ)
|
||||||
|
Ainv = inv(A)
|
||||||
|
∇A = - Ainv' * Δ * Ainv'
|
||||||
|
return (∇A, )
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# (/) rdivide
|
||||||
|
A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B)
|
||||||
|
A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B)
|
||||||
|
A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B)
|
||||||
|
@grad function (A / B)
|
||||||
|
return Tracker.data(A) / Tracker.data(B), function (Δ)
|
||||||
|
Binv = inv(B)
|
||||||
|
∇B = - Binv' * A' * Δ * Binv'
|
||||||
|
return (Δ * Binv', ∇B)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity)
|
||||||
|
A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||||
|
A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B)
|
||||||
|
A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B)
|
||||||
|
@grad function (A \ B)
|
||||||
|
return Tracker.data(A) \ Tracker.data(B), function (Δ)
|
||||||
|
Ainv = inv(A)
|
||||||
|
∇A = - Ainv' * Δ * B' * Ainv'
|
||||||
|
return (∇A, Ainv' * Δ)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
|
||||||
# Reductions
|
# Reductions
|
||||||
|
|
||||||
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims)
|
||||||
|
@ -129,6 +129,11 @@ end
|
|||||||
|
|
||||||
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
|
@test gradtest(f-> Matrix(Diagonal(f)), rand(3))
|
||||||
|
|
||||||
|
@test gradtest(W -> inv(log.(W * W)), (5,5))
|
||||||
|
@test gradtest((A, B) -> A / B , (1,5), (5,5))
|
||||||
|
@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5))
|
||||||
|
@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5))
|
||||||
|
|
||||||
@testset "mean" begin
|
@testset "mean" begin
|
||||||
@test gradtest(mean, rand(2, 3))
|
@test gradtest(mean, rand(2, 3))
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user