matrix-vector fast path

This commit is contained in:
Mike J Innes 2017-11-07 19:34:27 +00:00
parent efa51f02e7
commit d6423eefe5
2 changed files with 12 additions and 0 deletions

View File

@ -38,6 +38,8 @@ TrackedArray(c::Call) = TrackedArray(c, c())
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
isleaf(x::TrackedArray) = x.f == Call(nothing)
param(xs) = TrackedArray(AbstractFloat.(xs))
param(xs::Real) = param(fill(xs))

View File

@ -79,6 +79,16 @@ function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
@back(b, At_mul_B(data(a), Δ))
end
# Fast path for matrix-vector
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
if isleaf(W)
W.grad .+= Δ .* data(x).'
else
back(W, A_mul_Bt(Δ, data(x)))
end
@back(x, At_mul_B(data(W), Δ))
end
# NNlib
import NNlib: softmax, ∇softmax