Add higher order autodiff for getindex
This fixes higher-order autodiff for getindex (and by extension vcat, since its backward pass uses getindex). This makes Tracker able to handle the third-order derivatives needed to implement [1].

[1] Physics-Informed Generative Adversarial Networks for Stochastic Differential Equations. https://arxiv.org/abs/1811.02033
This commit is contained in:
parent 0469394715
commit d55f742533
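For context on what the change enables: Tracker's gradient takes a nest = true flag that keeps the backward pass itself on the tape, so gradient calls can be nested. A minimal second-order sketch, assuming the same Flux-era gradient(f, x; nest = true) API that the new test at the bottom of this diff uses:

    using Flux.Tracker: gradient

    # Second-derivative data for f(x) = sum(x[1:2])^2. The inner gradient
    # slices x with getindex, so the outer pass has to differentiate
    # getindex's backward pass -- exactly what this commit makes possible.
    g2 = gradient([1.0, 2.0, 3.0]; nest = true) do x
      sum(gradient(x; nest = true) do x
        sum(x[1:2])^2
      end[1])
    end[1]
    # With s = x[1] + x[2] = 3: the sum of the inner gradient is 4s,
    # so g2 should come out as [4.0, 4.0, 0.0].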
@@ -1,5 +1,5 @@
-init_grad(x) = zero(x)
-zero_grad!(x) = zero(x)
+init_grad(x) = zero_or_nothing(x)
+zero_grad!(x) = zero_or_nothing(x)
 zero_grad!(x::AbstractArray) = (x .= 0)
 
 scan(c::Call) = foreach(scan, c.args)
@@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
 accum!(x, Δ) = x .+ Δ
 accum!(x::AbstractArray, Δ) = (x .+= Δ)
 
+# This is correct, because the `.grad` field of a tracked value is constrained
+# by a type parameter constructed from the (concrete) forward value, and
+# thus for `nothing` to show up here, we're guaranteed that the value
+# was `nothing` during the forward pass (and thus we don't care about its
+# derivatives).
+accum!(x::Nothing, Δ) = x
+
 function back(x::Tracked, Δ, once)
   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
   ref = x.ref -= 1
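The new accum! method is small but load-bearing: once backward passes are themselves tracked, gradient slots can legitimately hold nothing (non-differentiable arguments such as indices), and accumulation has to pass that through rather than error. A standalone sketch of the dispatch, using just the three definitions above:

    accum!(x, Δ) = x .+ Δ
    accum!(x::AbstractArray, Δ) = (x .+= Δ)
    accum!(x::Nothing, Δ) = x

    accum!([1.0, 2.0], [0.5, 0.5])   # in-place accumulation: [1.5, 2.5]
    accum!(nothing, [0.5, 0.5])      # nothing: the value was nothing in the
                                     # forward pass, so no gradient is kept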
@@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
 
 function back_(g::Grads, c::Call, Δ)
   Δs = c.func(Δ)
-  (Δs isa Tuple && length(Δs) >= length(c.args)) ||
+  ((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) ||
     error("Gradient is not a tuple of length $(length(c.args))")
   foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
 end
@@ -95,12 +95,27 @@ end
 
 Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
 
-@grad function getindex(xs::AbstractArray, i...)
-  data(xs)[i...], function (Δ)
-    Δ′ = zero(xs)
-    Δ′[i...] = data(Δ)
-    (nobacksies(:getindex, Δ′), map(_->nothing, i)...)
-  end
-end
+struct ∇getindex{T,S}
+  xs::T
+  i::S
+end
+
+function (g::∇getindex)(Δ)
+  Δ′ = zero(g.xs)
+  Δ′[g.i...] = Δ
+  (Δ′, map(_->nothing, g.i)...)
+end
+(g::∇getindex)(Δ::TrackedArray) = track(g, Δ)
+
+@grad function (g::∇getindex)(Δ)
+  z, back = g(data(Δ)), function(Δ′′)
+    (Δ′′[1][g.i...],)
+  end
+  z, back
+end
+
+@grad function getindex(xs::AbstractArray, i...)
+  data(xs)[i...], ∇getindex(xs, i)
+end
 
 Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
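This hunk is the heart of the commit. The backward pass of getindex used to be an anonymous closure, which Tracker had no way to differentiate again; reifying it as the callable struct ∇getindex lets a tracked Δ be routed back through track (the Δ::TrackedArray method) with its own @grad rule. That rule is correct because scatter and gather are adjoint: the pullback of "write Δ into zeros(xs) at i" is "read at i". A plain-Julia sketch of the round trip, with hypothetical helper names:

    # scatter is what ∇getindex computes; gather is what its @grad rule
    # returns via Δ′′[1][g.i...].
    scatter(Δ, xs, i) = (Δ′ = zero(xs); Δ′[i...] = Δ; Δ′)
    gather(Δ′′, i) = Δ′′[i...]

    xs = [10.0, 20.0, 30.0]
    i  = (1:2,)
    Δ  = [1.0, 1.0]

    Δ′ = scatter(Δ, xs, i)   # [1.0, 1.0, 0.0]
    gather(Δ′, i)            # [1.0, 1.0] -- recovers Δ exactly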
@@ -414,10 +429,15 @@ using ForwardDiff: Dual, partials, value
 
 trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
 
-unbroadcast(x::AbstractArray, Δ) =
-  size(x) == size(Δ) ? Δ :
-  length(x) == length(Δ) ? trim(x, Δ) :
-    trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+function unbroadcast(x::AbstractArray, Δ)
+  if size(x) == size(Δ)
+    return Δ
+  elseif length(x) == length(Δ)
+    return trim(x, Δ)
+  else
+    return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
+  end
+end
 
 unbroadcast(x::Number, Δ) = sum(Δ)
 unbroadcast(x::Base.RefValue, _) = nothing
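The unbroadcast rewrite is behavior-preserving; the nested ternary simply becomes an explicit if/elseif chain. Its job is to reduce an incoming gradient back to the shape of an argument that broadcasting expanded, summing over the dimensions where that argument had size 1. A self-contained sketch using the two definitions above:

    trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))

    function unbroadcast(x::AbstractArray, Δ)
      if size(x) == size(Δ)
        return Δ
      elseif length(x) == length(Δ)
        return trim(x, Δ)
      else
        return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
      end
    end

    x = ones(3, 1)       # argument that broadcasting expanded to (3, 4)
    Δ = ones(3, 4)       # gradient arrives in the broadcast shape
    unbroadcast(x, Δ)    # 3×1 array of 4.0s, summed over the expanded dim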
@@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
 init_grad(x::Tuple) = init_grad.(x)
 zero_grad!(x::Tuple) = zero_grad!.(x)
 
-track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
+zero_or_nothing(x) = zero(x)
+zero_or_nothing(x::Nothing) = nothing
+track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))
 
 function Base.show(io::IO, xs::TrackedTuple)
   show(io, data(xs))
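zero_or_nothing ties this file to the first hunk and to the accum!(::Nothing, _) method above: tuple slots holding non-differentiable values get a nothing seed instead of a zero, and that nothing survives the whole backward pass. Its behavior in isolation:

    zero_or_nothing(x) = zero(x)
    zero_or_nothing(x::Nothing) = nothing

    zero_or_nothing(3.0)       # 0.0
    zero_or_nothing([1, 2])    # [0, 0]
    zero_or_nothing(nothing)   # nothing -- no gradient slot is allocated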
@@ -132,11 +134,18 @@ end
 Base.length(x::TrackedTuple) = length(data(x))
 
 Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
+Base.iterate(xs::TrackedTuple) = track(iterate, xs)
+Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i)
+Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))
 
 @grad function getindex(xs::TrackedTuple, i)
   data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
 end
 
+@grad function iterate(xs::TrackedTuple, i=1)
+  (data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
+end
+
 # Array collection
 
 function collect(xs)
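Both @grad rules share one idea: the adjoint of reading slot i of an n-tuple is a one-hot tuple that carries Δ in slot i (plus nothing for the index argument, which has no derivative). A standalone illustration with a hypothetical helper name:

    # Adjoint of t[i] for an n-tuple t: Δ lands in slot i, zeros elsewhere.
    onehot_tuple(Δ, i, n) = ntuple(j -> i == j ? Δ : 0, n)

    onehot_tuple(1.5, 2, 3)   # (0, 1.5, 0)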
@@ -5,6 +5,7 @@ using NNlib: conv, depthwiseconv
 using Printf: @sprintf
 using LinearAlgebra: diagm, dot, LowerTriangular, norm
 using Statistics: mean, std
+using ForwardDiff
 using Random
 # using StatsBase
 
@@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
 
 @testset "indexing & slicing" begin
   gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
+
+  # Nested AD for getindex
+  grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x
+    sum(gradient(x; nest=true) do x
+      sum(gradient(x; nest=true) do x
+        sum(x[1:2])^4
+      end[1])
+    end[1])
+  end[1]
+  # We compare to ForwardDiff, since the high order derivative is not
+  # numerically stable under finite differencing.
+  grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
+    sum(ForwardDiff.gradient(x) do x
+      sum(ForwardDiff.gradient(x) do x
+        sum(x[1:2])^4
+      end)
+    end)
+  end
+  @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0]
 end
 
 function promotiontest(f, A, B, C)
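The expected value is easy to verify by hand, which is worth doing since the test pins all three derivative orders at once. Writing s = x[1] + x[2] (so s = 3 at the test point):

    # f(x)              = s^4
    # sum(grad f)       = 2 * 4s^3  = 8s^3    (slots 1 and 2 each get 4s^3)
    # sum(grad of 8s^3) = 2 * 24s^2 = 48s^2
    # grad of 48s^2     = 96s on slots 1 and 2, zero on slot 3
    s = 1.0 + 2.0
    96s .* [1.0, 1.0, 0.0]   # [288.0, 288.0, 0.0], matching the @test above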