From 43af3895b097db21fd14d6f75c4df801d6dda25e Mon Sep 17 00:00:00 2001 From: chengchingwen Date: Wed, 7 Mar 2018 21:01:07 +0800 Subject: [PATCH] change `prod` implementation to avoid small xs --- src/tracker/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 8ccbc634..4089f505 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -144,7 +144,7 @@ Base.prod(xs::TrackedArray) = track(prod, xs) Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ) -back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (prod(xs.data) ./ xs.data) .* Δ) +back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ) Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)