From d36d7879df2ef22812fc60187e5d39c40296c017 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Mon, 4 Feb 2019 19:38:02 -0500 Subject: [PATCH] Add higher order autodiff for getindex This fixed higher order autodiff or getindex (and by extension vcat, since it's backward pass uses getindex). This makes tracker able to handle the third order derivatives needed to implement [1]. [1] Physics-Informed Generative Adversarial Networks for Stochastic Differential Equations. https://arxiv.org/abs/1811.02033 --- src/tracker/back.jl | 13 ++++++++++--- src/tracker/lib/array.jl | 38 +++++++++++++++++++++++++++++--------- src/tracker/lib/real.jl | 11 ++++++++++- test/tracker.jl | 20 ++++++++++++++++++++ 4 files changed, 69 insertions(+), 13 deletions(-) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index ef65ecb6..c4ed1d5c 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,5 +1,5 @@ -init_grad(x) = zero(x) -zero_grad!(x) = zero(x) +init_grad(x) = zero_or_nothing(x) +zero_grad!(x) = zero_or_nothing(x) zero_grad!(x::AbstractArray) = (x .= 0) scan(c::Call) = foreach(scan, c.args) @@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used") accum!(x, Δ) = x .+ Δ accum!(x::AbstractArray, Δ) = (x .+= Δ) +# This is correct, because the `.grad` field of tracked is constrained +# by a type parameter constructed from the (concrete) forward value and +# thus for `nothing` to show up here, we're guaranteed that the value +# was `nothing` during the forward pass (and thus we don't care about its +# derivatives). +accum!(x::Nothing, Δ) = x + function back(x::Tracked, Δ, once) x.isleaf && (x.grad = accum!(x.grad, Δ); return) ref = x.ref -= 1 @@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ function back_(g::Grads, c::Call, Δ) Δs = c.func(Δ) - (Δs isa Tuple && length(Δs) >= length(c.args)) || + ((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") foreach((x, Δ) -> back(g, x, Δ), c.args, Δs) end diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index f8697c88..6a2ab965 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -95,12 +95,27 @@ end Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) -@grad function getindex(xs::AbstractArray, i...) - data(xs)[i...], function (Δ) - Δ′ = zero(xs) - Δ′[i...] = data(Δ) - (nobacksies(:getindex, Δ′), map(_->nothing, i)...) +struct ∇getindex{T,S} + xs::T + i::S +end + +function (g::∇getindex)(Δ) + Δ′ = zero(g.xs) + Δ′[g.i...] = Δ + (Δ′, map(_->nothing, g.i)...) +end +(g::∇getindex)(Δ::TrackedArray) = track(g, Δ) + +@grad function (g::∇getindex)(Δ) + z, back = g(data(Δ)), function(Δ′′) + (Δ′′[1][g.i...],) end + z, back +end + +@grad function getindex(xs::AbstractArray, i...) + data(xs)[i...], ∇getindex(xs, i) end Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...) @@ -424,10 +439,15 @@ using ForwardDiff: Dual, partials, value trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - length(x) == length(Δ) ? trim(x, Δ) : - trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) +function unbroadcast(x::AbstractArray, Δ) + if size(x) == size(Δ) + return Δ + elseif length(x) == length(Δ) + return trim(x, Δ) + else + return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) + end +end unbroadcast(x::Number, Δ) = sum(Δ) unbroadcast(x::Base.RefValue, _) = nothing diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index ec57f0d3..1079c634 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) init_grad(x::Tuple) = init_grad.(x) zero_grad!(x::Tuple) = zero_grad!.(x) -track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs))) +zero_or_nothing(x) = zero(x) +zero_or_nothing(x::Nothing) = nothing +track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs))) function Base.show(io::IO, xs::TrackedTuple) show(io, data(xs)) @@ -132,11 +134,18 @@ end Base.length(x::TrackedTuple) = length(data(x)) Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) +Base.iterate(xs::TrackedTuple) = track(iterate, xs) +Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i) +Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i)) @grad function getindex(xs::TrackedTuple, i) data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing) end +@grad function iterate(xs::TrackedTuple, i=1) + (data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing) +end + # Array collection function collect(xs) diff --git a/test/tracker.jl b/test/tracker.jl index 817ba389..5ed61120 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -5,6 +5,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv using Printf: @sprintf using LinearAlgebra: diagm, dot, LowerTriangular, norm using Statistics: mean, std +using ForwardDiff using Random # using StatsBase @@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) @testset "indexing & slicing" begin gradtest(x->view(x, 1:2, 1:2), rand(4, 4)) + + # Nested AD for getindex + grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x + sum(gradient(x; nest=true) do x + sum(gradient(x; nest=true) do x + sum(x[1:2])^4 + end[1]) + end[1]) + end[1] + # We compare to ForwardDiff, since the high order derivative is not + # numerically stable under finite differencing. + grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x + sum(ForwardDiff.gradient(x) do x + sum(ForwardDiff.gradient(x) do x + sum(x[1:2])^4 + end) + end) + end + @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0] end function promotiontest(f, A, B, C)