diff --git a/src/tracker/array.jl b/src/tracker/array.jl index de950b99..d5c04b5c 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -1,3 +1,5 @@ +import Base: *, == + using LinearAlgebra struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} @@ -56,9 +58,9 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, T::Type) = similar(data(x), T) -Base.:(==)(x::TrackedArray, y) = data(x) == y -Base.:(==)(y, x::TrackedArray) = y == data(x) -Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y) +x::TrackedArray == y = data(x) == y +y == x::TrackedArray = y == data(x) +x::TrackedArray == y::TrackedArray = data(x) == data(y) # Array Stdlib @@ -77,10 +79,10 @@ Base.:-(xs::TrackedArray) = track(-, xs) @grad -(xs) = -data(xs), Δ -> (-Δ,) Base.transpose(xs::TrackedArray) = track(transpose, xs) -Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) +Base.adjoint(xs::TrackedArray) = track(adjoint, xs) @grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),) -@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) +@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...) @@ -269,31 +271,20 @@ end LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x) @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) -for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args - @eval begin - import Base.$f - $f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b) - $f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b) - $f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b) +x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) +y::AbstractMatrix * x::TrackedMatrix = track(*, x, y) +x::TrackedMatrix * y::TrackedMatrix = track(*, x, y) - $f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b) - $f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b) - $f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b) +x::TrackedMatrix * y::AbstractVector = track(*, x, y) +y::AbstractMatrix * x::TrackedVector = track(*, x, y) +x::TrackedMatrix * y::TrackedVector = track(*, x, y) - $f(a::TrackedVector, b::TrackedVector) = track($f, a, b) - $f(a::TrackedVector, b::AbstractVector) = track($f, a, b) - $f(a::AbstractVector, b::TrackedVector) = track($f, a, b) - end -end +x::TrackedVector * y::AbstractVector = track(*, x, y) +y::AbstractVector * x::TrackedVector = track(*, x, y) +x::TrackedVector * y::TrackedVector = track(*, x, y) @grad a::AbstractMatrix * b::AbstractVecOrMat = - data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ)) - -@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ) -@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)') - -@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ) -@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)') + data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ) # NNlib