scalar sum
This commit is contained in:
parent
ef681f16ea
commit
97af9db181
@ -1,15 +1,36 @@
|
|||||||
import Base: *
|
import Base: *
|
||||||
|
|
||||||
Base.getindex(xs::TrackedArray, i...) = TrackedArray(Call(getindex, xs, i...))
|
toarray(xs::AbstractArray, ys::AbstractArray) = ys
|
||||||
|
|
||||||
|
function toarray(xs::AbstractArray, y)
|
||||||
|
y′ = similar(xs, typeof(y), ())
|
||||||
|
y′[] = y
|
||||||
|
return y′
|
||||||
|
end
|
||||||
|
|
||||||
|
Base.getindex(xs::TrackedArray, i...) =
|
||||||
|
TrackedArray(Call(getindex, xs, i...), toarray(xs.x, xs.x[i...]))
|
||||||
|
|
||||||
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
function back!(::typeof(getindex), Δ, xs::TrackedArray, i...)
|
||||||
Δ′ = zeros(xs)
|
Δ′ = zeros(xs.x)
|
||||||
Δ′[i...] = Δ
|
Δ′[i...] = Δ
|
||||||
@back!(xs, Δ′)
|
@back!(xs, Δ′)
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
Base.:-(xs::TrackedArray) = TrackedArray(Call(-, xs))
|
||||||
|
|
||||||
|
# Reductions
|
||||||
|
|
||||||
|
Base.sum(xs::TrackedArray) = TrackedArray(Call(sum, xs), toarray(xs.x, sum(xs.x)))
|
||||||
|
|
||||||
|
function back!(::typeof(sum), Δ, xs::TrackedArray)
|
||||||
|
Δ′ = similar(xs.x)
|
||||||
|
Δ′ .= Δ
|
||||||
|
back!(xs, Δ′)
|
||||||
|
end
|
||||||
|
|
||||||
|
# BLAS
|
||||||
|
|
||||||
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||||
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
|
a::TrackedMatrix * b::AbstractMatrix = TrackedArray(Call(*, a, b))
|
||||||
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
a::AbstractMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||||
|
Loading…
Reference in New Issue
Block a user