diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index 41b7e676..3b2195f3 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -1,15 +1,36 @@ import Base: * -Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...)) +toarray(xs::AbstractArray, ys::AbstractArray) = ys + +function toarray(xs::AbstractArray, y) + y′ = similar(xs, typeof(y), ()) + y′[] = y + return y′ +end + +Base.getindex(xs::TrackedArray, i...) = + TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...])) function back!(::typeof(getindex), Δ, xs::TrackedArray, i...) - Δ′ = zeros(xs) + Δ′ = zeros(xs.x) Δ′[i...] = Δ @back!(xs, Δ′) end Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs)) +# Reductions + +Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x))) + +function back!(::typeof(sum), Δ, xs::TrackedArray) + Δ′ = similar(xs.x) + Δ′ .= Δ + back!(xs, Δ′) +end + +# BLAS + a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b)) a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))