diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 9f3adc6b..a90eb932 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -57,8 +57,10 @@ back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) +Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data))) Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) +back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data)) back(::typeof(mean), Δ, xs::TrackedArray, region) = back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))