Add a gradient method for products whose left operand is a tracked Transpose, and make
`trim` return its input unchanged when the ranks already match.

* src/tracker/lib/array.jl: new `@grad` method for
  `TrackedMatrix{<:Any, <:Transpose} * AbstractVecOrMat`; the first pullback component is
  written as `transpose(b * transpose(Δ))` rather than `Δ * transpose(b)`.
* src/tracker/lib/array.jl: `trim(x, Δ)` returns `Δ` as-is when `ndims(Δ) == ndims(x)` and
  only reshapes otherwise.
* test/tracker.jl: new "transpose" testset checking gradients taken through `transpose`.

NOTE(review): this patch was recovered from a copy whose line breaks had been collapsed.
Leading whitespace on context lines is reconstructed (2-space indentation assumed);
confirm against the target files, or apply with `git apply --ignore-whitespace`.

diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl
index 6a2ab965..096a2727 100644
--- a/src/tracker/lib/array.jl
+++ b/src/tracker/lib/array.jl
@@ -376,6 +376,9 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
 @grad a::AbstractMatrix * b::AbstractVecOrMat =
   data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
 
+@grad a::TrackedMatrix{<:Any, <:Transpose} * b::AbstractVecOrMat =
+  data(a)*data(b), Δ -> (transpose(b * transpose(Δ)), transpose(a) * Δ)
+
 # NNlib
 
 using NNlib
@@ -437,7 +440,7 @@ end
 
 using ForwardDiff: Dual, partials, value
 
-trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
+trim(x, Δ) = ndims(Δ) == ndims(x) ? Δ : reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
 
 function unbroadcast(x::AbstractArray, Δ)
   if size(x) == size(Δ)
diff --git a/test/tracker.jl b/test/tracker.jl
index 5ed61120..c67b2d06 100644
--- a/test/tracker.jl
+++ b/test/tracker.jl
@@ -356,4 +356,14 @@ end
   @test back([1, 1]) == (32,)
 end
 
+@testset "transpose" begin
+  let f = (x,a,b)->(x = transpose(x); x * a + x * b),
+      g = x->(a = transpose(x); b = transpose(a); b * [1.0 1.0; 2.0 3.0] + a * [1.0 1.0; 2.0 3.0])
+    @test gradient(x->sum(f(x, [1.0; 1.0], [1.0; 1.0])), [1.0 1.0; 1.0 1.0])[1] ==
+      [2.0 2.0; 2.0 2.0]
+    @test gradient(x->sum(g(x)), [1.0 1.0; 1.0 1.0])[1] ==
+      [4.0 7.0; 7.0 10.0]
+  end
+end
+
 end #testset