Add higher order autodiff for getindex

This fixes higher order autodiff for getindex (and, by extension, vcat, since
its backward pass uses getindex). This enables Tracker to handle the
third order derivatives needed to implement [1].

[1] Physics-Informed Generative Adversarial Networks for Stochastic Differential Equations. https://arxiv.org/abs/1811.02033
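
For context, what this enables is nested calls to `gradient(...; nest = true)` (Tracker's non-mutating gradient API), as exercised by the new test at the bottom of this diff. A minimal second-order sketch, assuming Flux's Tracker is in scope (the variable names here are illustrative, not from the diff):

    using Flux.Tracker: gradient

    x0 = [1.0, 2.0, 3.0]
    # d/dx of sum(d/dx of sum(x[1:2])^2): the second pass routes through
    # the (now differentiable) pullback of getindex.
    g2 = gradient(x0; nest = true) do x
      sum(gradient(x; nest = true) do x
        sum(x[1:2])^2
      end[1])
    end[1]
    # expected: [4.0, 4.0, 0.0]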
Keno Fischer, 2019-02-04 19:38:02 -05:00
parent 0469394715
commit d55f742533
4 changed files with 69 additions and 13 deletions


@@ -1,5 +1,5 @@
-init_grad(x) = zero(x)
-zero_grad!(x) = zero(x)
+init_grad(x) = zero_or_nothing(x)
+zero_grad!(x) = zero_or_nothing(x)
 zero_grad!(x::AbstractArray) = (x .= 0)

 scan(c::Call) = foreach(scan, c.args)
@@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)
+
+# This is correct, because the `.grad` field of tracked is constrained
+# by a type parameter constructed from the (concrete) forward value and
+# thus for `nothing` to show up here, we're guaranteed that the value
+# was `nothing` during the forward pass (and thus we don't care about its
+# derivatives).
+accum!(x::Nothing, Δ) = x

 function back(x::Tracked, Δ, once)
   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
   ref = x.ref -= 1
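
In other words (a hypothetical, self-contained illustration, not code from this diff): a `nothing` can only reach `accum!` when the corresponding forward value was `nothing`, so accumulation simply passes it through.

    accum_demo(x, Δ)          = x .+ Δ
    accum_demo(x::Nothing, Δ) = x   # forward value was `nothing`; no gradient to keep

    accum_demo([1.0, 2.0], [0.5, 0.5])  # -> [1.5, 2.5]
    accum_demo(nothing, [0.5, 0.5])     # -> nothing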
@@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ

 function back_(g::Grads, c::Call, Δ)
   Δs = c.func(Δ)
-  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
+  ((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) ||
     error("Gradient is not a tuple of length $(length(c.args))")
   foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
 end
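
The relaxed `Tuple || TrackedTuple` check ties into the `∇getindex` change further down: when the incoming sensitivity is itself tracked, the pullback re-enters the tape and its result comes back as a `TrackedTuple` rather than a plain `Tuple`. A rough sketch of the shape of things, using Tracker internals as assumed here (`∇getindex`, `param`):

    g  = ∇getindex(rand(3), (1:2,))   # pullback for xs[1:2]
    Δ  = param(ones(2))               # tracked sensitivity from an outer gradient pass
    Δs = g(Δ)                         # dispatches to track(g, Δ) -> a TrackedTuple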


@@ -95,12 +95,27 @@ end

 Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)

-@grad function getindex(xs::AbstractArray, i...)
-  data(xs)[i...], function (Δ)
-    Δ′ = zero(xs)
-    Δ′[i...] = data(Δ)
-    (nobacksies(:getindex, Δ′), map(_->nothing, i)...)
-  end
-end
+struct ∇getindex{T,S}
+  xs::T
+  i::S
+end
+
+function (g::∇getindex)(Δ)
+  Δ′ = zero(g.xs)
+  Δ′[g.i...] = Δ
+  (Δ′, map(_->nothing, g.i)...)
+end
+(g::∇getindex)(Δ::TrackedArray) = track(g, Δ)
+
+@grad function (g::∇getindex)(Δ)
+  z, back = g(data(Δ)), function(Δ′′)
+    (Δ′′[1][g.i...],)
+  end
+  z, back
+end
+
+@grad function getindex(xs::AbstractArray, i...)
+  data(xs)[i...], ∇getindex(xs, i)
+end

 Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
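
The idea, as a standalone sketch (hypothetical helper names, not Tracker code): the adjoint of `getindex` is a scatter of `Δ` into a zero array shaped like `xs`, and the adjoint of that scatter is just `getindex` again, which is why the single extra `@grad` method on the callable `∇getindex` closes the loop for arbitrarily deep nesting.

    # Pullback of y = xs[i...]: place Δ at positions i in a zero array like xs.
    scatter(Δ, xs, i...) = (Δ′ = zero(xs); Δ′[i...] = Δ; Δ′)

    # Pullback of the scatter w.r.t. Δ: read the same positions back out.
    unscatter(Δ′′, i...) = Δ′′[i...]

    xs  = collect(1.0:5.0)
    Δ′  = scatter([10.0, 20.0], xs, 1:2)   # [10.0, 20.0, 0.0, 0.0, 0.0]
    Δ′′ = unscatter(Δ′, 1:2)               # [10.0, 20.0], i.e. Δ recovered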
@@ -414,10 +429,15 @@ using ForwardDiff: Dual, partials, value

 trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

-unbroadcast(x::AbstractArray, Δ) =
-  size(x) == size(Δ) ? Δ :
-  length(x) == length(Δ) ? trim(x, Δ) :
-    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+function unbroadcast(x::AbstractArray, Δ)
+  if size(x) == size(Δ)
+    return Δ
+  elseif length(x) == length(Δ)
+    return trim(x, Δ)
+  else
+    return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+  end
+end

 unbroadcast(x::Number, Δ) = sum(Δ)
 unbroadcast(x::Base.RefValue, _) = nothing
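
For reference, an illustrative case of what `unbroadcast` computes (behavior unchanged by the restructuring): when `x` was broadcast against a larger `Δ`, the gradient is summed back over the expanded singleton dimensions and reshaped to `size(x)`.

    x = rand(3, 1)          # argument broadcast along its second (singleton) dimension
    Δ = rand(3, 4)          # gradient w.r.t. the broadcast result
    # Neither sizes nor lengths match, so the third branch applies; here it
    # amounts to summing over dim 2 and reshaping back to size(x):
    Δx = reshape(sum(Δ, dims = 2), 3, 1)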


@@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
 init_grad(x::Tuple) = init_grad.(x)
 zero_grad!(x::Tuple) = zero_grad!.(x)

-track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
+zero_or_nothing(x) = zero(x)
+zero_or_nothing(x::Nothing) = nothing
+track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))

 function Base.show(io::IO, xs::TrackedTuple)
   show(io, data(xs))
@@ -132,11 +134,18 @@ end
 Base.length(x::TrackedTuple) = length(data(x))

 Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
+Base.iterate(xs::TrackedTuple) = track(iterate, xs)
+Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i)
+Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))

 @grad function getindex(xs::TrackedTuple, i)
   data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
 end

+@grad function iterate(xs::TrackedTuple, i=1)
+  (data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
+end
+
 # Array collection

 function collect(xs)
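
These `iterate` methods matter because `back_` (in the hunk above) runs `foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)`, which walks `Δs` alongside `c.args`; when `Δs` is a `TrackedTuple`, that iteration has to stay on the tape. A hypothetical illustration, assuming `tt` is a `TrackedTuple` of length 2:

    y1 = iterate(tt)    # tracked (element, next_state) pair
    a  = y1[1]          # tracked element; gradients flow back into tt
    s  = y1[2]          # tracked state, handled by the TrackedReal method above
    y2 = iterate(tt, s)
    b  = y2[1]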


@@ -5,6 +5,7 @@ using NNlib: conv, depthwiseconv
 using Printf: @sprintf
 using LinearAlgebra: diagm, dot, LowerTriangular, norm
 using Statistics: mean, std
+using ForwardDiff
 using Random

 # using StatsBase
@@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)

 @testset "indexing & slicing" begin
   gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
+  # Nested AD for getindex
+  grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x
+    sum(gradient(x; nest=true) do x
+      sum(gradient(x; nest=true) do x
+        sum(x[1:2])^4
+      end[1])
+    end[1])
+  end[1]
+
+  # We compare to ForwardDiff, since the higher order derivative is not
+  # numerically stable under finite differencing.
+  grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
+    sum(ForwardDiff.gradient(x) do x
+      sum(ForwardDiff.gradient(x) do x
+        sum(x[1:2])^4
+      end)
+    end)
+  end
+  @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0]
 end

 function promotiontest(f, A, B, C)
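
As a sanity check on the expected value in the new test: writing s = x[1] + x[2], the innermost function is s^4, whose gradient is 4s^3 * [1, 1, 0] and sums to 8s^3; the gradient of that is 24s^2 * [1, 1, 0], summing to 48s^2; one more gradient gives 96s * [1, 1, 0]. At x = [1.0, 2.0, 3.0], s = 3, so 96 * 3 = 288 and the third-order gradient is [288.0, 288.0, 0.0], matching both Tracker and ForwardDiff.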