Backpropagation for `maximum` and `minimum`
This commit is contained in:
parent
159ca536ec
commit
cfd29b9c76
|
@ -156,12 +156,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
|
|||
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
|
||||
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
|
||||
|
||||
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
|
||||
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
|
||||
|
||||
Base.mean(xs::TrackedArray) = track(mean, xs)
|
||||
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
|
||||
|
||||
Base.maximum(xs::TrackedArray) = track(maximum, xs)
|
||||
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
|
||||
Base.minimum(xs::TrackedArray) = track(minimum, xs)
|
||||
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
|
||||
|
||||
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
|
||||
|
@ -184,6 +188,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./
|
|||
back(::typeof(mean), Δ, xs::TrackedArray, region) =
|
||||
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
|
||||
|
||||
function back(::typeof(maximum), Δ, xs::TrackedArray)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, i = findmax(xs.data)
|
||||
Δ′[i] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, is = findmax(xs.data, region)
|
||||
Δ′[is] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(minimum), Δ, xs::TrackedArray)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, i = findmin(xs.data)
|
||||
Δ′[i] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
|
||||
Δ′ = zeros(xs.data)
|
||||
_, is = findmin(xs.data, region)
|
||||
Δ′[is] = Δ
|
||||
@back(xs, Δ′)
|
||||
end
|
||||
|
||||
# BLAS
|
||||
|
||||
Base.diagm(x::TrackedVector) = track(diagm, x)
|
||||
|
|
|
@ -55,6 +55,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
|
|||
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "maximum" begin
|
||||
@test gradtest(maximum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> maximum(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@testset "minimum" begin
|
||||
@test gradtest(minimum, rand(2, 3))
|
||||
|
||||
@test gradtest(x -> minimum(x, 1), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, 2), rand(2, 3))
|
||||
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
|
||||
|
||||
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
|
||||
end
|
||||
|
||||
@test gradtest(x -> std(x), rand(5,5))
|
||||
@test gradtest(x -> std(x, 1), rand(5,5))
|
||||
|
||||
|
|
Loading…
Reference in New Issue