Compare commits: master...kf/getinde (1 commit)
Author | SHA1 | Date |
---|---|---|
![]() |
d55f742533 |
|
@ -1,5 +1,5 @@
|
|||
# Generic gradient-slot helpers. Both defer to `zero_or_nothing`, so a value
# of `nothing` gets a `nothing` gradient while every other value gets
# `zero(x)`. Note that the generic `zero_grad!` does not mutate `x`; it returns
# a fresh value.
init_grad(x) = zero_or_nothing(x)
zero_grad!(x) = zero_or_nothing(x)
|
||||
# Reset an array-valued gradient in place by overwriting every element with
# zero. The same (mutated) array is returned.
function zero_grad!(x::AbstractArray)
    x .= 0
    return x
end
|
||||
|
||||
# Apply `scan` to every argument stored on the call node `c`.
# Returns `nothing`.
function scan(c::Call)
    for arg in c.args
        scan(arg)
    end
    return nothing
end
|
||||
|
@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
|||
# Generic accumulation: returns the broadcasted sum of `x` and `Δ` without
# mutating either argument.
function accum!(x, Δ)
    return x .+ Δ
end

# Array accumulation: adds `Δ` into `x` elementwise, in place, and returns `x`.
function accum!(x::AbstractArray, Δ)
    x .+= Δ
    return x
end
|
||||
|
||||
# A `nothing` gradient slot stays `nothing`; the incoming `Δ` is discarded.
# This is correct because the `.grad` field of a tracked value is constrained
# by a type parameter constructed from the (concrete) forward value. For
# `nothing` to show up here, the value must therefore have been `nothing`
# during the forward pass, and in that case its derivatives are of no interest.
accum!(::Nothing, Δ) = nothing
|
||||
|
||||
function back(x::Tracked, Δ, once)
|
||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||
ref = x.ref -= 1
|
||||
|
@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
|||
|
||||
# Back-propagate the sensitivity `Δ` through the call node `c`, accumulating
# into the gradient store `g`.
#
# `c.func` is applied to `Δ` and must yield a `Tuple` or `TrackedTuple` holding
# at least one entry per element of `c.args`; otherwise an error is raised.
# Each argument is then paired with its gradient entry and handed to `back`.
# Returns `nothing`.
function back_(g::Grads, c::Call, Δ)
    Δs = c.func(Δ)
    # A single check covers both plain and tracked tuples.
    if !((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args))
        error("Gradient is not a tuple of length $(length(c.args))")
    end
    foreach((arg, δ) -> back(g, arg, δ), c.args, Δs)
    return nothing
end
|
||||
|
|
|
@ -95,12 +95,27 @@ end
|
|||
|
||||
# Indexing a tracked array is routed through `track`, with `getindex` as the
# recorded operation and the indices forwarded unchanged.
function Base.getindex(xs::TrackedArray, i...)
    return track(getindex, xs, i...)
end
|
||||
|
||||
@grad function getindex(xs::AbstractArray, i...)
|
||||
data(xs)[i...], function (Δ)
|
||||
Δ′ = zero(xs)
|
||||
Δ′[i...] = data(Δ)
|
||||
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
|
||||
# Callable pullback object for `getindex`. It stores the value that was
# indexed (`xs`) together with the indices used (`i`), so the gradient can be
# rebuilt later when the object is called.
struct ∇getindex{T,S}
    xs::T   # the indexed value
    i::S    # the indices it was indexed with
end
|
||||
|
||||
# Pullback of `g.xs[g.i...]` for an incoming sensitivity `Δ`: start from
# `zero(g.xs)`, write `Δ` into the positions selected by `g.i`, and return that
# gradient followed by one `nothing` per index.
function (g::∇getindex)(Δ)
    grad = zero(g.xs)
    setindex!(grad, Δ, g.i...)
    return (grad, map(_ -> nothing, g.i)...)
end
|
||||
# When the incoming sensitivity is itself a tracked array, apply the pullback
# through `track` rather than evaluating it directly.
function (g::∇getindex)(Δ::TrackedArray)
    return track(g, Δ)
end
|
||||
|
||||
# Gradient rule for applying the pullback object `g` to a sensitivity `Δ`.
# The forward value is `g` applied to the untracked data of `Δ`. The pullback
# takes the first component of the incoming cotangent and reads back the
# entries selected by `g.i`.
@grad function (g::∇getindex)(Δ)
    forward = g(data(Δ))
    pullback = Δ′′ -> (Δ′′[1][g.i...],)
    return forward, pullback
end
|
||||
|
||||
# Gradient rule for `getindex` on arrays: the forward value indexes the
# untracked data of `xs`, and the pullback is a `∇getindex` object capturing
# `xs` and the index tuple `i`.
@grad function getindex(xs::AbstractArray, i...)
    value = data(xs)[i...]
    return value, ∇getindex(xs, i)
end
|
||||
|
||||
# Taking a view of a tracked array is routed through `track`, with `Base.view`
# as the recorded operation and the indices forwarded unchanged.
function Base.view(x::TrackedArray, inds...)
    return track(Base.view, x, inds...)
end
|
||||
|
@ -414,10 +429,15 @@ using ForwardDiff: Dual, partials, value
|
|||
|
||||
# Reshape `Δ` so that it has exactly `ndims(x)` dimensions, keeping the sizes
# of `Δ` along its first `ndims(x)` dimensions.
function trim(x, Δ)
    dims = ntuple(d -> size(Δ, d), Val(ndims(x)))
    return reshape(Δ, dims)
end
|
||||
|
||||
# Reduce a gradient `Δ` produced by broadcasting back to the shape of the
# array argument `x`.
#
# - Same size: `Δ` is returned unchanged.
# - Same number of elements: `Δ` is only reshaped via `trim`.
# - Otherwise: `Δ` is summed over every dimension in which `x` has size 1,
#   then reshaped via `trim`.
function unbroadcast(x::AbstractArray, Δ)
    if size(x) == size(Δ)
        return Δ
    elseif length(x) == length(Δ)
        return trim(x, Δ)
    else
        # Dimensions that must not be reduced are mapped to `ndims(Δ) + 1`,
        # which lies beyond the last dimension of `Δ` and so sums nothing.
        dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ) + 1, Val(ndims(Δ)))
        return trim(x, sum(Δ, dims = dims))
    end
end
|
||||
|
||||
# A scalar argument receives the sum of the whole broadcasted gradient.
function unbroadcast(x::Number, Δ)
    return sum(Δ)
end

# `Ref`-wrapped arguments get no gradient.
unbroadcast(::Base.RefValue, _) = nothing
|
||||
|
|
|
@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
|||
# Tuples are handled elementwise: each entry gets its own initial gradient.
init_grad(x::Tuple) = map(init_grad, x)

# Likewise, each entry of a tuple is reset individually.
zero_grad!(x::Tuple) = map(zero_grad!, x)
|
||||
|
||||
# `zero(x)` for ordinary values, with `nothing` mapped to `nothing` (Base
# defines no `zero` method for `Nothing`).
zero_or_nothing(x) = zero(x)
zero_or_nothing(::Nothing) = nothing

# Wrap a tuple result `xs` of the call `f` in a `TrackedTuple`. The gradient
# slot is initialised elementwise with `zero_or_nothing`, so tuples containing
# `nothing` entries can be tracked.
track(f::Call, xs::Tuple) =
    TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))
|
||||
|
||||
function Base.show(io::IO, xs::TrackedTuple)
|
||||
show(io, data(xs))
|
||||
|
@ -132,11 +134,18 @@ end
|
|||
# A tracked tuple has the length of the tuple it wraps.
function Base.length(x::TrackedTuple)
    return length(data(x))
end

# Integer indexing is routed through `track`, with `getindex` as the recorded
# operation.
function Base.getindex(xs::TrackedTuple, i::Integer)
    return track(getindex, xs, i)
end
|
||||
# Iteration protocol for tracked tuples. Each step is routed through `track`
# with `iterate` as the recorded operation.
#
# As with plain tuples, `nothing` is returned once the state falls outside
# `1:length`, so `for x in xs` and `collect`-style consumers terminate instead
# of hitting a `BoundsError` in the forward rule.
Base.iterate(xs::TrackedTuple) =
    length(data(xs)) == 0 ? nothing : track(iterate, xs)

Base.iterate(xs::TrackedTuple, i::Integer) =
    1 <= i <= length(data(xs)) ? track(iterate, xs, i) : nothing

# The state returned by a tracked step may itself be tracked; unwrap it before
# continuing the iteration.
Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))
|
||||
|
||||
# Gradient rule for indexing a tracked tuple. The forward value is the `i`-th
# entry of the underlying data. The pullback yields a tuple that carries `Δ`
# in slot `i` and `0` in every other slot, plus `nothing` for the index itself.
@grad function getindex(xs::TrackedTuple, i)
    value = data(xs)[i]
    pullback = Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
    return value, pullback
end
|
||||
|
||||
# Gradient rule for one iteration step over a tracked tuple. The forward value
# is the pair `(element i, next state)`. The pullback places the first
# component of the incoming `Δ` in slot `i`, `0` in every other slot, and
# returns `nothing` for the state argument.
@grad function iterate(xs::TrackedTuple, i=1)
    step = (data(xs)[i], i + 1)
    pullback = Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
    return step, pullback
end
|
||||
|
||||
# Array collection
|
||||
|
||||
function collect(xs)
|
||||
|
|
|
@ -5,6 +5,7 @@ using NNlib: conv, depthwiseconv
|
|||
using Printf: @sprintf
|
||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||
using Statistics: mean, std
|
||||
using ForwardDiff
|
||||
using Random
|
||||
# using StatsBase
|
||||
|
||||
|
@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
|||
|
||||
@testset "indexing & slicing" begin
    gradtest(x -> view(x, 1:2, 1:2), rand(4, 4))

    # Nested AD for getindex: differentiate `sum(x[1:2])^4` three times by
    # nesting `gradient` calls, summing the intermediate gradient each time.
    grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do outer
        sum(gradient(outer; nest=true) do middle
            sum(gradient(middle; nest=true) do inner
                sum(inner[1:2])^4
            end[1])
        end[1])
    end[1]

    # We compare to ForwardDiff, since the high order derivative is not
    # numerically stable under finite differencing.
    grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do outer
        sum(ForwardDiff.gradient(outer) do middle
            sum(ForwardDiff.gradient(middle) do inner
                sum(inner[1:2])^4
            end)
        end)
    end

    @test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0]
end
|
||||
|
||||
function promotiontest(f, A, B, C)
|
||||
|
|
Loading…
Reference in New Issue