diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index 3b2195f3..6003ebd0 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -1,12 +1,7 @@ import Base: * toarray(xs::AbstractArray, ys::AbstractArray) = ys - -function toarray(xs::AbstractArray, y) - y′ = similar(xs, typeof(y), ()) - y′[] = y - return y′ -end +toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...])) @@ -21,13 +16,10 @@ Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs)) # Reductions +Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x))) -function back!(::typeof(sum), Δ, xs::TrackedArray) - Δ′ = similar(xs.x) - Δ′ .= Δ - back!(xs, Δ′) -end +back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ) # BLAS