scalar sum

This commit is contained in:
Mike J Innes 2017-08-22 12:24:08 +01:00
parent ef681f16ea
commit 97af9db181

View File

@ -1,15 +1,36 @@
import Base: *
Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...))
toarray(xs::AbstractArray, ys::AbstractArray) = ys
function toarray(xs::AbstractArray, y)
y = similar(xs, typeof(y), ())
y[] = y
return y
end
Base.getindex(xs::TrackedArray, i...) =
TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...]))
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
Δ′ = zeros(xs)
Δ′ = zeros(xs.x)
Δ′[i...] = Δ
@back!(xs, Δ′)
end
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
# Reductions
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x)))
function back!(::typeof(sum), Δ, xs::TrackedArray)
Δ′ = similar(xs.x)
Δ′ .= Δ
back!(xs, Δ′)
end
# BLAS
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))