matrix-vector fast path
This commit is contained in:
parent
efa51f02e7
commit
d6423eefe5
@ -38,6 +38,8 @@ TrackedArray(c::Call) = TrackedArray(c, c())
|
||||
|
||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||
|
||||
isleaf(x::TrackedArray) = x.f == Call(nothing)
|
||||
|
||||
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||
param(xs::Real) = param(fill(xs))
|
||||
|
||||
|
@ -79,6 +79,16 @@ function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
|
||||
@back(b, At_mul_B(data(a), Δ))
|
||||
end
|
||||
|
||||
# Fast path for matrix-vector
|
||||
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
|
||||
if isleaf(W)
|
||||
W.grad .+= Δ .* data(x).'
|
||||
else
|
||||
back(W, A_mul_Bt(Δ, data(x)))
|
||||
end
|
||||
@back(x, At_mul_B(data(W), Δ))
|
||||
end
|
||||
|
||||
# NNlib
|
||||
|
||||
import NNlib: softmax, ∇softmax
|
||||
|
Loading…
Reference in New Issue
Block a user