TrackedArray: implement mean
```julia julia> p Tracked 2×3 Array{Float64,2}: 1.0 3.0 5.0 2.0 4.0 6.0 ``` Before ```julia julia> @benchmark Flux.Tracker.back!(sum($p, 2) ./ size($p, 2), ones(2, 1)) BenchmarkTools.Trial: memory estimate: 3.44 KiB allocs estimate: 75 -------------- minimum time: 20.438 μs (0.00% GC) median time: 21.239 μs (0.00% GC) mean time: 22.354 μs (1.68% GC) maximum time: 3.811 ms (98.51% GC) -------------- samples: 10000 evals/sample: 1 ``` After ```julia julia> @benchmark Flux.Tracker.back!(mean($p, 2), ones(2, 1)) BenchmarkTools.Trial: memory estimate: 1008 bytes allocs estimate: 21 -------------- minimum time: 5.973 μs (0.00% GC) median time: 6.310 μs (0.00% GC) mean time: 6.630 μs (1.96% GC) maximum time: 680.709 μs (97.28% GC) -------------- samples: 10000 evals/sample: 6 ```
This commit is contained in:
parent
4c1b1eb18c
commit
c43bda019b
@ -57,6 +57,11 @@ back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .=
|
||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
||||
|
||||
back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
||||
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
|
||||
|
||||
# BLAS
|
||||
|
||||
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||
|
@ -22,6 +22,16 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
||||
@test gradtest(vcat, rand(5), rand(3))
|
||||
@test gradtest(vcat, rand(2,3), rand(3,3))
|
||||
|
||||
@testset "mean" begin
|
||||
@test gradtest(mean, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> mean(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> mean(x, 3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@test gradtest(rand(5)) do x
|
||||
y = x.^2
|
||||
2y + x
|
||||
|
Loading…
Reference in New Issue
Block a user