matrix-vector fast path
This commit is contained in:
parent
efa51f02e7
commit
d6423eefe5
@ -38,6 +38,8 @@ TrackedArray(c::Call) = TrackedArray(c, c())
|
|||||||
|
|
||||||
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x))
|
||||||
|
|
||||||
|
isleaf(x::TrackedArray) = x.f == Call(nothing)
|
||||||
|
|
||||||
param(xs) = TrackedArray(AbstractFloat.(xs))
|
param(xs) = TrackedArray(AbstractFloat.(xs))
|
||||||
param(xs::Real) = param(fill(xs))
|
param(xs::Real) = param(fill(xs))
|
||||||
|
|
||||||
|
@ -79,6 +79,16 @@ function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat)
|
|||||||
@back(b, At_mul_B(data(a), Δ))
|
@back(b, At_mul_B(data(a), Δ))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Fast path for matrix-vector
|
||||||
|
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
|
||||||
|
if isleaf(W)
|
||||||
|
W.grad .+= Δ .* data(x).'
|
||||||
|
else
|
||||||
|
back(W, A_mul_Bt(Δ, data(x)))
|
||||||
|
end
|
||||||
|
@back(x, At_mul_B(data(W), Δ))
|
||||||
|
end
|
||||||
|
|
||||||
# NNlib
|
# NNlib
|
||||||
|
|
||||||
import NNlib: softmax, ∇softmax
|
import NNlib: softmax, ∇softmax
|
||||||
|
Loading…
Reference in New Issue
Block a user