diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 5065a40d..ab250e39 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -70,7 +70,7 @@ back(::typeof(mean), Δ, xs::TrackedArray, region) = # BLAS -for f in :[*, Ac_mul_B].args +for f in :[*, Ac_mul_B, A_mul_Bc].args @eval begin import Base.$f $f(a::TrackedMatrix, b::TrackedMatrix) = TrackedArray(Call($f, a, b)) @@ -94,7 +94,12 @@ end function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) @back(a, A_mul_Bt(Δ, data(b))') - @back(b, *(data(a), Δ)) + @back(b, data(a)*Δ) +end + +function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) + @back(a, Δ * data(b)) + @back(b, At_mul_B(data(a), Δ)') end # Fast path for matrix-vector diff --git a/test/tracker.jl b/test/tracker.jl index 81a72566..7d9ef4f5 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -10,6 +10,7 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest((x, W, b) -> σ.(W*x .+ b), (5,3), (2,5), 2) @test gradtest((w, x) -> w'*x, randn(10, 2), randn(10)) +@test gradtest((w, x) -> w*x', randn(5,5), randn(5,5)) @test gradtest(x -> sin.(sum(x, (2, 3))), (3,4,5))