Handle various cases of multiplying transpose-wrapped matrices
See the test cases. I hit these while taking third-order derivatives of matrix multiplications (whose gradient definitions use `transpose`).
This commit is contained in:
parent
53875a85a1
commit
dfb4e2e8ab
|
@ -361,6 +361,9 @@ x::TrackedVector * y::TrackedVector = track(*, x, y)
|
|||
# Gradient rule for a matrix times a vector or matrix. The right-hand side is a
# pair: the product data(a)*data(b), and a function of Δ that returns
# (Δ * transpose(b), transpose(a) * Δ), i.e. one entry for `a` and one for `b`.
# NOTE(review): data(...) presumably strips the tracking wrapper and Δ is presumably
# the gradient with respect to the product — confirm against the @grad definition.
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
||||
|
||||
# More specific gradient rule for a TrackedMatrix whose second type parameter is a
# Transpose. The entry for `a` is written as transpose(b * transpose(Δ)), which is
# algebraically equal to Δ * transpose(b) but is produced through a final
# `transpose` call; the entry for `b` is transpose(a) * Δ.
# NOTE(review): presumably this keeps the gradient for `a` transpose-wrapped so that
# higher-order derivatives dispatch consistently — confirm against the test cases.
@grad a::TrackedMatrix{<:Any, <:Transpose} * b::AbstractVecOrMat =
|
||||
data(a)*data(b), Δ -> (transpose(b * transpose(Δ)), transpose(a) * Δ)
|
||||
|
||||
# NNlib
|
||||
|
||||
using NNlib
|
||||
|
@ -412,7 +415,7 @@ end
|
|||
|
||||
using ForwardDiff: Dual, partials, value
|
||||
|
||||
# Reshape Δ so that it has ndims(x) dimensions, keeping Δ's own size along each one.
function trim(x, Δ)
    dims = ntuple(d -> size(Δ, d), Val(ndims(x)))
    return reshape(Δ, dims)
end
|
||||
# Give Δ the same number of dimensions as x. When the counts already agree, Δ is
# returned as-is (no reshape is applied); otherwise Δ is reshaped to its own sizes
# along the first ndims(x) dimensions.
function trim(x, Δ)
    ndims(Δ) == ndims(x) && return Δ
    dims = ntuple(d -> size(Δ, d), Val(ndims(x)))
    return reshape(Δ, dims)
end
|
||||
|
||||
unbroadcast(x::AbstractArray, Δ) =
|
||||
size(x) == size(Δ) ? Δ :
|
||||
|
|
|
@ -328,4 +328,14 @@ end
|
|||
@test back([1, 1]) == (32,)
|
||||
end
|
||||
|
||||
# Gradients of products whose left operand is a transpose, or a transpose of a transpose.
@testset "transpose" begin
    weights = [1.0 1.0; 2.0 3.0]

    # Transpose x once and multiply the result by two separate right-hand sides.
    f = function (x, a, b)
        xt = transpose(x)
        return xt * a + xt * b
    end

    # Combine a doubly transposed operand with a singly transposed one.
    g = function (x)
        a = transpose(x)
        b = transpose(a)
        return b * weights + a * weights
    end

    @test gradient(x -> sum(f(x, [1.0; 1.0], [1.0; 1.0])), [1.0 1.0; 1.0 1.0])[1] ==
          [2.0 2.0; 2.0 2.0]
    @test gradient(x -> sum(g(x)), [1.0 1.0; 1.0 1.0])[1] ==
          [4.0 7.0; 7.0 10.0]
end
|
||||
|
||||
end #testset
|
||||
|
|
Loading…
Reference in New Issue