Backpropagation for `maximum` and `minimum`

This commit is contained in:
Pontus Stenetorp 2018-04-27 22:14:01 +01:00
parent 159ca536ec
commit cfd29b9c76
2 changed files with 50 additions and 1 deletions

View File

@ -156,12 +156,16 @@ Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ)
Base.maximum(xs::TrackedArray, args...) = maximum(xs.data, args...)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
Base.mean(xs::TrackedArray) = track(mean, xs)
Base.mean(xs::TrackedArray, region) = track(mean, xs, region)
Base.maximum(xs::TrackedArray) = track(maximum, xs)
Base.maximum(xs::TrackedArray, region) = track(maximum, xs, region)
Base.minimum(xs::TrackedArray) = track(minimum, xs)
Base.minimum(xs::TrackedArray, region) = track(minimum, xs, region)
LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@ -184,6 +188,31 @@ back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./
back(::typeof(mean), Δ, xs::TrackedArray, region) =
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
function back(::typeof(maximum), Δ, xs::TrackedArray)
Δ′ = zeros(xs.data)
_, i = findmax(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
end
function back(::typeof(maximum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmax(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray)
Δ′ = zeros(xs.data)
_, i = findmin(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmin(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
end
# BLAS
Base.diagm(x::TrackedVector) = track(diagm, x)

View File

@ -55,6 +55,26 @@ gradtest(f, dims...) = gradtest(f, rand.(dims)...)
@test gradtest(x -> mean(x, [1, 2]), rand(2, 3, 4))
end
@testset "maximum" begin
@test gradtest(maximum, rand(2, 3))
@test gradtest(x -> maximum(x, 1), rand(2, 3))
@test gradtest(x -> maximum(x, 2), rand(2, 3))
@test gradtest(x -> maximum(x, 3), rand(2, 3, 4))
@test gradtest(x -> maximum(x, [1, 2]), rand(2, 3, 4))
end
@testset "minimum" begin
@test gradtest(minimum, rand(2, 3))
@test gradtest(x -> minimum(x, 1), rand(2, 3))
@test gradtest(x -> minimum(x, 2), rand(2, 3))
@test gradtest(x -> minimum(x, 3), rand(2, 3, 4))
@test gradtest(x -> minimum(x, [1, 2]), rand(2, 3, 4))
end
@test gradtest(x -> std(x), rand(5,5))
@test gradtest(x -> std(x, 1), rand(5,5))