From 86d782a5ce8c30cd573ee5bc929af54d89276114 Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Tue, 6 Mar 2018 18:01:19 +0800 Subject: [PATCH] implement `prod` for `TrackedArray` --- src/tracker/array.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index eb5d03c7..2b62883b 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -139,6 +139,13 @@ Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ) +Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) +Base.prod(xs::TrackedArray) = track(prod, xs) +Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) + +back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim) ./ xs.data) .* Δ) +back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (prod(xs.data) ./ xs.data) .* Δ) + Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)