diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index d6fa6f35..5e26a051 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -38,6 +38,8 @@ TrackedArray(c::Call) = TrackedArray(c, c()) TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) +isleaf(x::TrackedArray) = x.f == Call(nothing) + param(xs) = TrackedArray(AbstractFloat.(xs)) param(xs::Real) = param(fill(xs)) diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index a90eb932..2ee5d659 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -79,6 +79,16 @@ function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) @back(b, At_mul_B(data(a), Δ)) end +# Fast path for matrix-vector +function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector) + if isleaf(W) + W.grad .+= Δ .* data(x).' + else + back(W, A_mul_Bt(Δ, data(x))) + end + @back(x, At_mul_B(data(W), Δ)) +end + # NNlib import NNlib: softmax, ∇softmax