Compare commits

...

1 Commits

Author SHA1 Message Date
Keno Fischer d55f742533 Add higher order autodiff for getindex
This fixed higher order autodiff or getindex (and by extension vcat, since
it's backward pass uses getindex). This makes tracker able to handle
the third order derivatives needed to implement [1].

[1] Physics-Informed Generative Adversarial Networks for Stochastic Differential Equations. https://arxiv.org/abs/1811.02033
2019-02-05 23:49:48 -05:00
4 changed files with 69 additions and 13 deletions

View File

@ -1,5 +1,5 @@
init_grad(x) = zero(x)
zero_grad!(x) = zero(x)
init_grad(x) = zero_or_nothing(x)
zero_grad!(x) = zero_or_nothing(x)
zero_grad!(x::AbstractArray) = (x .= 0)
scan(c::Call) = foreach(scan, c.args)
@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
# This is correct, because the `.grad` field of tracked is constrained
# by a type parameter constructed from the (concrete) forward value and
# thus for `nothing` to show up here, we're guaranteed that the value
# was `nothing` during the forward pass (and thus we don't care about its
# derivatives).
accum!(x::Nothing, Δ) = x
function back(x::Tracked, Δ, once)
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1
@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end

View File

@ -95,12 +95,27 @@ end
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], function (Δ)
Δ′ = zero(xs)
Δ′[i...] = data(Δ)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
struct ∇getindex{T,S}
xs::T
i::S
end
function (g::∇getindex)(Δ)
Δ′ = zero(g.xs)
Δ′[g.i...] = Δ
(Δ′, map(_->nothing, g.i)...)
end
(g::∇getindex)(Δ::TrackedArray) = track(g, Δ)
@grad function (g::∇getindex)(Δ)
z, back = g(data(Δ)), function(Δ′′)
(Δ′′[1][g.i...],)
end
z, back
end
@grad function getindex(xs::AbstractArray, i...)
data(xs)[i...], ∇getindex(xs, i)
end
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
@ -414,10 +429,15 @@ using ForwardDiff: Dual, partials, value
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
length(x) == length(Δ) ? trim(x, Δ) :
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
function unbroadcast(x::AbstractArray, Δ)
if size(x) == size(Δ)
return Δ
elseif length(x) == length(Δ)
return trim(x, Δ)
else
return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
end
end
unbroadcast(x::Number, Δ) = sum(Δ)
unbroadcast(x::Base.RefValue, _) = nothing

View File

@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
zero_or_nothing(x) = zero(x)
zero_or_nothing(x::Nothing) = nothing
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))
function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs))
@ -132,11 +134,18 @@ end
Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
Base.iterate(xs::TrackedTuple) = track(iterate, xs)
Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i)
Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))
@grad function getindex(xs::TrackedTuple, i)
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
@grad function iterate(xs::TrackedTuple, i=1)
(data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
end
# Array collection
function collect(xs)

View File

@ -5,6 +5,7 @@ using NNlib: conv, depthwiseconv
using Printf: @sprintf
using LinearAlgebra: diagm, dot, LowerTriangular, norm
using Statistics: mean, std
using ForwardDiff
using Random
# using StatsBase
@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
@testset "indexing & slicing" begin
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
# Nested AD for getindex
grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x
sum(gradient(x; nest=true) do x
sum(gradient(x; nest=true) do x
sum(x[1:2])^4
end[1])
end[1])
end[1]
# We compare to ForwardDiff, since the high order derivative is not
# numerically stable under finite differencing.
grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
sum(ForwardDiff.gradient(x) do x
sum(ForwardDiff.gradient(x) do x
sum(x[1:2])^4
end)
end)
end
@test grad_tracker grad_forward [288.0, 288.0, 0.0]
end
function promotiontest(f, A, B, C)