Add higher order autodiff for getindex

This fixed higher order autodiff or getindex (and by extension vcat, since
it's backward pass uses getindex). This makes tracker able to handle
the third order derivatives needed to implement [1].

[1] Physics-Informed Generative Adversarial Networks for Stochastic Differential Equations. https://arxiv.org/abs/1811.02033
This commit is contained in:
Keno Fischer 2019-02-04 19:38:02 -05:00
parent 0469394715
commit d55f742533
4 changed files with 69 additions and 13 deletions

View File

@ -1,5 +1,5 @@
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
init_grad(x) = zero_or_nothing(x)
zero_grad!(x) = zero_or_nothing(x)
zero_grad!(x::AbstractArray) = (x .= 0)
scan(c::Call) = foreach(scan, c.args)
@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
# This is correct, because the `.grad` field of tracked is constrained
# by a type parameter constructed from the (concrete) forward value and
# thus for `nothing` to show up here, we're guaranteed that the value
# was `nothing` during the forward pass (and thus we don't care about its
# derivatives).
accum!(x::Nothing, Δ) = x
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end

View File

@ -95,12 +95,27 @@ end
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], function (Δ)
Δ′ = zero(xs)
Δ′[i...] = data(Δ)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
struct ∇getindex{T,S}
xs::T
i::S
end
function (g::∇getindex)(Δ)
Δ′ = zero(g.xs)
Δ′[g.i...] = Δ
(Δ′, map(_->nothing, g.i)...)
end
(g::∇getindex)(Δ::TrackedArray) = track(g, Δ)
@grad function (g::∇getindex)(Δ)
z, back = g(data(Δ)), function(Δ′′)
(Δ′′[1][g.i...],)
end
z, back
end
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], ∇getindex(xs, i)
end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@ -414,10 +429,15 @@ using ForwardDiff: Dual, partials, value
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
function unbroadcast(x::AbstractArray, Δ)
if size(x) == size(Δ)
return Δ
elseif length(x) == length(Δ)
return trim(x, Δ)
else
return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
end
end
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue, _) = nothing

View File

@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
zero_or_nothing(x) = zero(x)
zero_or_nothing(x::Nothing) = nothing
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))
function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs))
@ -132,11 +134,18 @@ end
Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
Base.iterate(xs::TrackedTuple) = track(iterate, xs)
Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i)
Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))
@grad function getindex(xs::TrackedTuple, i)
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
@grad function iterate(xs::TrackedTuple, i=1)
(data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
end
# Array collection
function collect(xs)

View File

@ -5,6 +5,7 @@ using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
using ForwardDiff
using Random
# using StatsBase
@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
# Nested AD for getindex
grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x
sum(gradient(x; nest=true) do x
sum(gradient(x; nest=true) do x
sum(x[1:2])^4
end[1])
end[1])
end[1]
# We compare to ForwardDiff, since the high order derivative is not
# numerically stable under finite differencing.
grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
sum(ForwardDiff.gradient(x) do x
sum(ForwardDiff.gradient(x) do x
sum(x[1:2])^4
end)
end)
end
@test grad_tracker grad_forward [288.0, 288.0, 0.0]
end
function promotiontest(f, A, B, C)