diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index 0e8c573f..3cfdd382 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -14,7 +14,11 @@ a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) -function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractMatrix) +a::TrackedMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) +a::TrackedMatrix * b::AbstractVector = TrackedArray(Call(*, a, b)) +a::AbstractMatrix * b::TrackedVector = TrackedArray(Call(*, a, b)) + +function back!(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) @back!(a, A_mul_Bt(Δ, data(b))) @back!(b, At_mul_B(data(a), Δ)) end