diff --git a/src/tracker/lib.jl b/src/tracker/lib.jl index 254be8dc..9f3adc6b 100644 --- a/src/tracker/lib.jl +++ b/src/tracker/lib.jl @@ -57,6 +57,11 @@ back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) +Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region)) + +back(::typeof(mean), Δ, xs::TrackedArray, region) = + back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...))) + # BLAS a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b)) diff --git a/test/tracker.jl b/test/tracker.jl index 2a20338e..52a73a07 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -22,6 +22,16 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(vcat, rand(5), rand(3)) @test gradtest(vcat, rand(2,3), rand(3,3)) +@testset "mean" begin + @test gradtest(mean, rand(2, 3)) + + @test gradtest(x -> mean(x, 1), rand(2, 3)) + @test gradtest(x -> mean(x, 2), rand(2, 3)) + @test gradtest(x -> mean(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4)) +end + @test gradtest(rand(5)) do x y = x.^2 2y + x