Handle various cases of multiplying transpose-wrapped matrices

See test cases. I hit these while taking third-order derivatives of
matrix multiplies (whose gradient definitions use transpose).
This commit is contained in:
Keno Fischer 2019-02-05 21:28:44 -05:00
parent 53875a85a1
commit dfb4e2e8ab
2 changed files with 14 additions and 1 deletions

View File

@ -361,6 +361,9 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
@grad a::TrackedMatrix{<:Any, <:Transpose} * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (transpose(b * transpose(Δ)), transpose(a) * Δ)
# NNlib
using NNlib
@ -412,7 +415,7 @@ end
using ForwardDiff: Dual, partials, value
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
trim(x, Δ) = ndims(Δ) == ndims(x) ? Δ : reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :

View File

@ -328,4 +328,14 @@ end
@test back([1, 1]) == (32,)
end
@testset "transpose" begin
let f = (x,a,b)->(x = transpose(x); x * a + x * b),
g = x->(a = transpose(x); b = transpose(a); b * [1.0 1.0; 2.0 3.0] + a * [1.0 1.0; 2.0 3.0])
@test gradient(x->sum(f(x, [1.0; 1.0], [1.0; 1.0])), [1.0 1.0; 1.0 1.0])[1] ==
[2.0 2.0; 2.0 2.0]
@test gradient(x->sum(g(x)), [1.0 1.0; 1.0 1.0])[1] ==
[4.0 7.0; 7.0 10.0]
end
end
end #testset