diff --git a/src/tracker/back.jl b/src/tracker/back.jl index ef65ecb6..c4ed1d5c 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -1,5 +1,5 @@ -init_grad(x) = zero(x) -zero_grad!(x) = zero(x) +init_grad(x) = zero_or_nothing(x) +zero_grad!(x) = zero_or_nothing(x) zero_grad!(x::AbstractArray) = (x .= 0) scan(c::Call) = foreach(scan, c.args) @@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used") accum!(x, Δ) = x .+ Δ accum!(x::AbstractArray, Δ) = (x .+= Δ) +# This is correct, because the `.grad` field of tracked is constrained +# by a type parameter constructed from the (concrete) forward value and +# thus for `nothing` to show up here, we're guaranteed that the value +# was `nothing` during the forward pass (and thus we don't care about its +# derivatives). +accum!(x::Nothing, Δ) = x + function back(x::Tracked, Δ, once) x.isleaf && (x.grad = accum!(x.grad, Δ); return) ref = x.ref -= 1 @@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ function back_(g::Grads, c::Call, Δ) Δs = c.func(Δ) - (Δs isa Tuple && length(Δs) >= length(c.args)) || + ((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") foreach((x, Δ) -> back(g, x, Δ), c.args, Δs) end diff --git a/src/tracker/lib/array.jl b/src/tracker/lib/array.jl index f8697c88..6a2ab965 100644 --- a/src/tracker/lib/array.jl +++ b/src/tracker/lib/array.jl @@ -95,12 +95,27 @@ end Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) -@grad function getindex(xs::AbstractArray, i...) - data(xs)[i...], function (Δ) - Δ′ = zero(xs) - Δ′[i...] = data(Δ) - (nobacksies(:getindex, Δ′), map(_->nothing, i)...) +struct ∇getindex{T,S} + xs::T + i::S +end + +function (g::∇getindex)(Δ) + Δ′ = zero(g.xs) + Δ′[g.i...] = Δ + (Δ′, map(_->nothing, g.i)...) +end +(g::∇getindex)(Δ::TrackedArray) = track(g, Δ) + +@grad function (g::∇getindex)(Δ) + z, back = g(data(Δ)), function(Δ′′) + (Δ′′[1][g.i...],) end + z, back +end + +@grad function getindex(xs::AbstractArray, i...) + data(xs)[i...], ∇getindex(xs, i) end Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...) @@ -424,10 +439,15 @@ using ForwardDiff: Dual, partials, value trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x)))) -unbroadcast(x::AbstractArray, Δ) = - size(x) == size(Δ) ? Δ : - length(x) == length(Δ) ? trim(x, Δ) : - trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) +function unbroadcast(x::AbstractArray, Δ) + if size(x) == size(Δ) + return Δ + elseif length(x) == length(Δ) + return trim(x, Δ) + else + return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ))))) + end +end unbroadcast(x::Number, Δ) = sum(Δ) unbroadcast(x::Base.RefValue, _) = nothing diff --git a/src/tracker/lib/real.jl b/src/tracker/lib/real.jl index ec57f0d3..1079c634 100644 --- a/src/tracker/lib/real.jl +++ b/src/tracker/lib/real.jl @@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) init_grad(x::Tuple) = init_grad.(x) zero_grad!(x::Tuple) = zero_grad!.(x) -track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs))) +zero_or_nothing(x) = zero(x) +zero_or_nothing(x::Nothing) = nothing +track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs))) function Base.show(io::IO, xs::TrackedTuple) show(io, data(xs)) @@ -132,11 +134,18 @@ end Base.length(x::TrackedTuple) = length(data(x)) Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) +Base.iterate(xs::TrackedTuple) = track(iterate, xs) +Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i) +Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i)) @grad function getindex(xs::TrackedTuple, i) data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing) end +@grad function iterate(xs::TrackedTuple, i=1) + (data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing) +end + # Array collection function collect(xs) diff --git a/test/tracker.jl b/test/tracker.jl index 817ba389..5ed61120 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -5,6 +5,7 @@ using NNlib: conv, ∇conv_data, depthwiseconv using Printf: @sprintf using LinearAlgebra: diagm, dot, LowerTriangular, norm using Statistics: mean, std +using ForwardDiff using Random # using StatsBase @@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) @testset "indexing & slicing" begin gradtest(x->view(x, 1:2, 1:2), rand(4, 4)) + + # Nested AD for getindex + grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x + sum(gradient(x; nest=true) do x + sum(gradient(x; nest=true) do x + sum(x[1:2])^4 + end[1]) + end[1]) + end[1] + # We compare to ForwardDiff, since the high order derivative is not + # numerically stable under finite differencing. + grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x + sum(ForwardDiff.gradient(x) do x + sum(ForwardDiff.gradient(x) do x + sum(x[1:2])^4 + end) + end) + end + @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0] end function promotiontest(f, A, B, C)