add scalar mean
This commit is contained in:
parent
1186170542
commit
3d8b7250ae
@ -57,8 +57,10 @@ back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .=
|
||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Base.mean(xs::TrackedArray) = TrackedArray(Call(mean, xs), toarray(xs.data, mean(xs.data)))
|
||||
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
||||
|
||||
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data))
|
||||
back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
||||
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user