fix matmul

This commit is contained in:
Mike J Innes 2018-06-20 14:11:31 +01:00
parent 7057ca739e
commit fb8a220659

View File

@ -1,3 +1,5 @@
import Base: *, ==
using LinearAlgebra
struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
@ -56,9 +58,9 @@ Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
Base.similar(x::TrackedArray, T::Type) = similar(data(x), T)
Base.:(==)(x::TrackedArray, y) = data(x) == y
Base.:(==)(y, x::TrackedArray) = y == data(x)
Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
x::TrackedArray == y = data(x) == y
y == x::TrackedArray = y == data(x)
x::TrackedArray == y::TrackedArray = data(x) == data(y)
# Array Stdlib
@ -77,10 +79,10 @@ Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -data(xs), Δ -> (-Δ,)
Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
Base.adjoint(xs::TrackedArray) = track(adjoint, xs)
@grad transpose(xs) = transpose(data(xs)), Δ -> (reshape(transpose(Δ), size(xs)),)
@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
@grad adjoint(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
Base.repeat(A::TrackedArray; kw...) = track_kw(repeat, A; kw...)
@ -269,31 +271,20 @@ end
LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
@eval begin
import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
$f(a::TrackedMatrix, b::AbstractMatrix) = track($f, a, b)
$f(a::AbstractMatrix, b::TrackedMatrix) = track($f, a, b)
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
y::AbstractMatrix * x::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
$f(a::TrackedMatrix, b::TrackedVector) = track($f, a, b)
$f(a::TrackedMatrix, b::AbstractVector) = track($f, a, b)
$f(a::AbstractMatrix, b::TrackedVector) = track($f, a, b)
x::TrackedMatrix * y::AbstractVector = track(*, x, y)
y::AbstractMatrix * x::TrackedVector = track(*, x, y)
x::TrackedMatrix * y::TrackedVector = track(*, x, y)
$f(a::TrackedVector, b::TrackedVector) = track($f, a, b)
$f(a::TrackedVector, b::AbstractVector) = track($f, a, b)
$f(a::AbstractVector, b::TrackedVector) = track($f, a, b)
end
end
x::TrackedVector * y::AbstractVector = track(*, x, y)
y::AbstractVector * x::TrackedVector = track(*, x, y)
x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
# NNlib