diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 6d3c3b3f..3d9836d0 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -1,6 +1,8 @@ import Base: * import LinearAlgebra +import LinearAlgebra: inv, \, / + using Statistics using LinearAlgebra: Transpose, Adjoint, diagm, diag @@ -205,6 +207,41 @@ Base.kron(a::TrackedMatrix, b::TrackedMatrix) = _kron(a, b) Base.kron(a::TrackedMatrix, b::AbstractMatrix) = _kron(a, b) Base.kron(a::AbstractMatrix, b::TrackedMatrix) = _kron(a, b) + +inv(A::TrackedArray) = Tracker.track(inv, A) +@grad function inv(A) + return inv(Tracker.data(A)), function (Δ) + Ainv = inv(A) + ∇A = - Ainv' * Δ * Ainv' + return (∇A, ) + end +end + +# (/) rdivide +A::TrackedArray / B::TrackedArray = Tracker.track(/, A, B) +A::AbstractVecOrMat / B::TrackedArray = Tracker.track(/, A, B) +A::TrackedArray / B::AbstractVecOrMat = Tracker.track(/, A, B) +@grad function (A / B) + return Tracker.data(A) / Tracker.data(B), function (Δ) + Binv = inv(B) + ∇B = - Binv' * A' * Δ * Binv' + return (Δ * Binv', ∇B) + end +end + +# (\) ldivide (left vec divide needs more work to resolve dispatch ambiguity) +A::TrackedArray \ B::TrackedArray = Tracker.track(\, A, B) +A::AbstractArray \ B::TrackedArray = Tracker.track(\, A, B) +A::TrackedArray \ B::AbstractVecOrMat = Tracker.track(\, A, B) +@grad function (A \ B) + return Tracker.data(A) \ Tracker.data(B), function (Δ) + Ainv = inv(A) + ∇A = - Ainv' * Δ * B' * Ainv' + return (∇A, Ainv' * Δ) + end +end + + # Reductions Base.sum(xs::TrackedArray; dims = :) = track(sum, xs, dims = dims) diff --git a/test/tracker.jl b/test/tracker.jl index 9a4cb793..a4772f2e 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -129,6 +129,11 @@ end @test gradtest(f-> Matrix(Diagonal(f)), rand(3)) +@test gradtest(W -> inv(log.(W * W)), (5,5)) +@test gradtest((A, B) -> A / B , (1,5), (5,5)) +@test gradtest((A, B) -> log.(A * A) / exp.(B * B), (5,5), (5,5)) +@test gradtest((A, B) -> log.(A * A) \ exp.(B * B), (5,5), (5,5)) + @testset "mean" begin @test gradtest(mean, rand(2, 3))