implement prod for TrackedArray

This commit is contained in:
chengchingwen 2018-03-06 18:01:19 +08:00
parent eab26be0af
commit 86d782a5ce

View File

@ -139,6 +139,13 @@ Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ)
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (prod(xs.data) ./ xs.data) .* Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)