diff --git a/src/tracker/array.jl b/src/tracker/array.jl index eb5d03c7..2b62883b 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -139,6 +139,13 @@ Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ) +Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) +Base.prod(xs::TrackedArray) = track(prod, xs) +Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) + +back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim) ./ xs.data) .* Δ) +back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (prod(xs.data) ./ xs.data) .* Δ) + Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)