diff --git a/src/tracker/array.jl b/src/tracker/array.jl index 89dcfac9..4dfb2c6d 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -156,12 +156,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ) back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ) -Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.mean(xs::TrackedArray) = track(mean, xs) Base.mean(xs::TrackedArray, region) = track(mean, xs, region) +Base.maximum(xs::TrackedArray) = track(maximum, xs) +Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region) +Base.minimum(xs::TrackedArray) = track(minimum, xs) +Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region) + LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) @@ -184,6 +188,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ back(::typeof(mean), Δ, xs::TrackedArray, region) = back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...))) +function back(::typeof(maximum), Δ, xs::TrackedArray) + Δ′ = zeros(xs.data) + _, i = findmax(xs.data) + Δ′[i] = Δ + @back(xs, Δ′) +end +function back(::typeof(maximum), Δ, xs::TrackedArray, region) + Δ′ = zeros(xs.data) + _, is = findmax(xs.data, region) + Δ′[is] = Δ + @back(xs, Δ′) +end +function back(::typeof(minimum), Δ, xs::TrackedArray) + Δ′ = zeros(xs.data) + _, i = findmin(xs.data) + Δ′[i] = Δ + @back(xs, Δ′) +end +function back(::typeof(minimum), Δ, xs::TrackedArray, region) + Δ′ = zeros(xs.data) + _, is = findmin(xs.data, region) + Δ′[is] = Δ + @back(xs, Δ′) +end + # BLAS Base.diagm(x::TrackedVector) = track(diagm, x) diff --git a/test/tracker.jl b/test/tracker.jl index 0f5b6189..12ed02e5 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -55,6 +55,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...) @test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4)) end +@testset "maximum" begin + @test gradtest(maximum, rand(2, 3)) + + @test gradtest(x -> maximum(x, 1), rand(2, 3)) + @test gradtest(x -> maximum(x, 2), rand(2, 3)) + @test gradtest(x -> maximum(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4)) +end + +@testset "minimum" begin + @test gradtest(minimum, rand(2, 3)) + + @test gradtest(x -> minimum(x, 1), rand(2, 3)) + @test gradtest(x -> minimum(x, 2), rand(2, 3)) + @test gradtest(x -> minimum(x, 3), rand(2, 3, 4)) + + @test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4)) +end + @test gradtest(x -> std(x), rand(5,5)) @test gradtest(x -> std(x, 1), rand(5,5))