vector sum
This commit is contained in:
parent
97af9db181
commit
cd45df1eca
@ -1,12 +1,7 @@
|
||||
import Base: *
|
||||
|
||||
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
||||
|
||||
function toarray(xs::AbstractArray, y)
|
||||
y′ = similar(xs, typeof(y), ())
|
||||
y′[] = y
|
||||
return y′
|
||||
end
|
||||
toarray(xs::AbstractArray, y) = similar(xs, typeof(y), ()) .= y
|
||||
|
||||
Base.getindex(xs::TrackedArray, i...) =
|
||||
TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...]))
|
||||
@ -21,13 +16,10 @@ Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||
|
||||
# Reductions
|
||||
|
||||
Base.sum(xs::TrackedArray, dim) = TrackedArray(Call(sum, xs, dim))
|
||||
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x)))
|
||||
|
||||
function back!(::typeof(sum), Δ, xs::TrackedArray)
|
||||
Δ′ = similar(xs.x)
|
||||
Δ′ .= Δ
|
||||
back!(xs, Δ′)
|
||||
end
|
||||
back!(::typeof(sum), Δ, xs::TrackedArray, dim...) = back!(xs, similar(xs.x) .= Δ)
|
||||
|
||||
# BLAS
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user