fix matmul
This commit is contained in:
parent
7057ca739e
commit
fb8a220659
@ -1,3 +1,5 @@
|
||||
import Base: *, ==
|
||||
|
||||
using LinearAlgebra
|
||||
|
||||
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
|
||||
@ -56,9 +58,9 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
|
||||
|
||||
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
|
||||
|
||||
Base.:(==)(x::TrackedArray, y) = data(x) == y
|
||||
Base.:(==)(y, x::TrackedArray) = y == data(x)
|
||||
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
|
||||
x::TrackedArray == y = data(x) == y
|
||||
y == x::TrackedArray = y == data(x)
|
||||
x::TrackedArray == y::TrackedArray = data(x) == data(y)
|
||||
|
||||
# Array Stdlib
|
||||
|
||||
@ -77,10 +79,10 @@ Base.:-(xs::TrackedArray) = track(-, xs)
|
||||
@grad -(xs) = -data(xs), Δ -> (-Δ,)
|
||||
|
||||
Base.transpose(xs::TrackedArray) = track(transpose, xs)
|
||||
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
|
||||
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
|
||||
|
||||
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
|
||||
@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
|
||||
|
||||
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
|
||||
|
||||
@ -269,31 +271,20 @@ end
|
||||
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
|
||||
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
|
||||
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
|
||||
@eval begin
|
||||
import Base.$f
|
||||
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
|
||||
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
|
||||
y::AbstractMatrix * x::TrackedMatrix = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
|
||||
|
||||
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
|
||||
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
|
||||
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
|
||||
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
|
||||
y::AbstractMatrix * x::TrackedVector = track(*, x, y)
|
||||
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
|
||||
|
||||
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
|
||||
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
|
||||
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
|
||||
end
|
||||
end
|
||||
x::TrackedVector * y::AbstractVector = track(*, x, y)
|
||||
y::AbstractVector * x::TrackedVector = track(*, x, y)
|
||||
x::TrackedVector * y::TrackedVector = track(*, x, y)
|
||||
|
||||
@grad a::AbstractMatrix * b::AbstractVecOrMat =
|
||||
data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
|
||||
|
||||
@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
|
||||
@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
|
||||
@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
|
||||
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
|
||||
|
||||
# NNlib
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user