diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 5b8ddd13..7111d780 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -276,20 +276,29 @@ LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x) @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) -y::AbstractMatrix * x::TrackedMatrix = track(*, x, y) +x::AbstractMatrix * y::TrackedMatrix = track(*, x, y) x::TrackedMatrix * y::TrackedMatrix = track(*, x, y) x::TrackedMatrix * y::AbstractVector = track(*, x, y) -y::AbstractMatrix * x::TrackedVector = track(*, x, y) +x::AbstractMatrix * y::TrackedVector = track(*, x, y) x::TrackedMatrix * y::TrackedVector = track(*, x, y) x::TrackedVector * y::AbstractVector = track(*, x, y) -y::AbstractVector * x::TrackedVector = track(*, x, y) +x::AbstractVector * y::TrackedVector = track(*, x, y) x::TrackedVector * y::TrackedVector = track(*, x, y) @grad a::AbstractMatrix * b::AbstractVecOrMat = data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ) +# @grad function (a::AbstractMatrix * b::AbstractVecOrMat) +# # @show size(a) size(b) +# data(a)*data(b), function (Δ) +# @show size(Δ) size(b) size(Δ*transpose(b)) size(Δ*transpose(data(b))) +# @show typeof(Δ) typeof(b) +# (Δ * transpose(b), transpose(a) * Δ) +# end +# end + # NNlib using NNlib