From cd45df1eca2e66c02049351743d831f524c8a2ae Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 22 Aug 2017 15:12:12 +0100 Subject: [PATCH] vector sum --- src/Tracker/lib.jl | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/Tracker/lib.jl b/src/Tracker/lib.jl index 3b2195f3..6003ebd0 100644 --- a/src/Tracker/lib.jl +++ b/src/Tracker/lib.jl @@ -1,12 +1,7 @@ import Base: * toarray(xs::AbstractArray, ys::AbstractArray) = ys - -function toarray(xs::AbstractArray, y) - y′ = similar(xs, typeof(y), ()) - y′[] = y - return y′ -end +toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...])) @@ -21,13 +16,10 @@ Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs)) # Reductions +Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim)) Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x))) -function back!(::typeof(sum), Δ, xs::TrackedArray) - Δ′ = similar(xs.x) - Δ′ .= Δ - back!(xs, Δ′) -end +back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ) # BLAS