matmul fix

This commit is contained in:
Mike Innes 2018-07-30 17:04:18 +01:00 committed by Mike J Innes
parent 4cf6bac0c1
commit f5c9361617

View File

@ -276,20 +276,29 @@ LinearAlgebra.diagm(x::TrackedVector) = track(diagm, x)
@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
x::TrackedMatrix * y::AbstractMatrix = track(*, x, y) x::TrackedMatrix * y::AbstractMatrix = track(*, x, y)
y::AbstractMatrix * x::TrackedMatrix = track(*, x, y) x::AbstractMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::TrackedMatrix = track(*, x, y) x::TrackedMatrix * y::TrackedMatrix = track(*, x, y)
x::TrackedMatrix * y::AbstractVector = track(*, x, y) x::TrackedMatrix * y::AbstractVector = track(*, x, y)
y::AbstractMatrix * x::TrackedVector = track(*, x, y) x::AbstractMatrix * y::TrackedVector = track(*, x, y)
x::TrackedMatrix * y::TrackedVector = track(*, x, y) x::TrackedMatrix * y::TrackedVector = track(*, x, y)
x::TrackedVector * y::AbstractVector = track(*, x, y) x::TrackedVector * y::AbstractVector = track(*, x, y)
y::AbstractVector * x::TrackedVector = track(*, x, y) x::AbstractVector * y::TrackedVector = track(*, x, y)
x::TrackedVector * y::TrackedVector = track(*, x, y) x::TrackedVector * y::TrackedVector = track(*, x, y)
@grad a::AbstractMatrix * b::AbstractVecOrMat = @grad a::AbstractMatrix * b::AbstractVecOrMat =
data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ) data(a)*data(b), Δ -> (Δ * transpose(b), transpose(a) * Δ)
# @grad function (a::AbstractMatrix * b::AbstractVecOrMat)
# # @show size(a) size(b)
# data(a)*data(b), function (Δ)
# @show size(Δ) size(b) size(Δ*transpose(b)) size(Δ*transpose(data(b)))
# @show typeof(Δ) typeof(b)
# (Δ * transpose(b), transpose(a) * Δ)
# end
# end
# NNlib # NNlib
using NNlib using NNlib