TrackedArray: implement mean

```julia
julia> p
Tracked 2×3 Array{Float64,2}:
 1.0  3.0  5.0
 2.0  4.0  6.0
```

Before
```julia
julia> @benchmark Flux.Tracker.back!(sum($p, 2) ./ size($p, 2), ones(2, 1))
BenchmarkTools.Trial:
  memory estimate:  3.44 KiB
  allocs estimate:  75
  --------------
  minimum time:     20.438 μs (0.00% GC)
  median time:      21.239 μs (0.00% GC)
  mean time:        22.354 μs (1.68% GC)
  maximum time:     3.811 ms (98.51% GC)
  --------------
  samples:          10000
  evals/sample:     1
```

After
```julia
julia> @benchmark Flux.Tracker.back!(mean($p, 2), ones(2, 1))
BenchmarkTools.Trial:
  memory estimate:  1008 bytes
  allocs estimate:  21
  --------------
  minimum time:     5.973 μs (0.00% GC)
  median time:      6.310 μs (0.00% GC)
  mean time:        6.630 μs (1.96% GC)
  maximum time:     680.709 μs (97.28% GC)
  --------------
  samples:          10000
  evals/sample:     6
```
This commit is contained in:
Iblis Lin 2017-10-30 16:21:02 +08:00
parent 4c1b1eb18c
commit c43bda019b
2 changed files with 15 additions and 0 deletions

View File

@ -57,6 +57,11 @@ back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .=
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
# BLAS
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))

View File

@ -22,6 +22,16 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(vcat, rand(5), rand(3))
@test gradtest(vcat, rand(2,3), rand(3,3))
@testset "mean" begin
@test gradtest(mean, rand(2, 3))
@test gradtest(x -> mean(x, 1), rand(2, 3))
@test gradtest(x -> mean(x, 2), rand(2, 3))
@test gradtest(x -> mean(x, 3), rand(2, 3, 4))
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end
@test gradtest(rand(5)) do x
y = x.^2
2y + x