Merge pull request #99 from iblis17/mean
TrackedArray: implement `mean`
This commit is contained in:
commit
1186170542
@ -57,6 +57,11 @@ back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .=
|
|||||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
||||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||||
|
|
||||||
|
Base.mean(xs::TrackedArray, region) = TrackedArray(Call(mean, xs, region))
|
||||||
|
|
||||||
|
back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
||||||
|
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
|
||||||
|
|
||||||
# BLAS
|
# BLAS
|
||||||
|
|
||||||
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
a::TrackedMatrix * b::TrackedMatrix = TrackedArray(Call(*, a, b))
|
||||||
|
@ -22,6 +22,16 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||||||
@test gradtest(vcat, rand(5), rand(3))
|
@test gradtest(vcat, rand(5), rand(3))
|
||||||
@test gradtest(vcat, rand(2,3), rand(3,3))
|
@test gradtest(vcat, rand(2,3), rand(3,3))
|
||||||
|
|
||||||
|
@testset "mean" begin
|
||||||
|
@test gradtest(mean, rand(2, 3))
|
||||||
|
|
||||||
|
@test gradtest(x -> mean(x, 1), rand(2, 3))
|
||||||
|
@test gradtest(x -> mean(x, 2), rand(2, 3))
|
||||||
|
@test gradtest(x -> mean(x, 3), rand(2, 3, 4))
|
||||||
|
|
||||||
|
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
|
||||||
|
end
|
||||||
|
|
||||||
@test gradtest(rand(5)) do x
|
@test gradtest(rand(5)) do x
|
||||||
y = x.^2
|
y = x.^2
|
||||||
2y + x
|
2y + x
|
||||||
|
Loading…
Reference in New Issue
Block a user