Compare commits
1 Commits
master
...
kf/getinde
Author | SHA1 | Date | |
---|---|---|---|
![]() |
d55f742533 |
@ -1,5 +1,5 @@
|
|||||||
init_grad(x) = zero(x)
|
init_grad(x) = zero_or_nothing(x)
|
||||||
zero_grad!(x) = zero(x)
|
zero_grad!(x) = zero_or_nothing(x)
|
||||||
zero_grad!(x::AbstractArray) = (x .= 0)
|
zero_grad!(x::AbstractArray) = (x .= 0)
|
||||||
|
|
||||||
scan(c::Call) = foreach(scan, c.args)
|
scan(c::Call) = foreach(scan, c.args)
|
||||||
@ -32,6 +32,13 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
|
|||||||
accum!(x, Δ) = x .+ Δ
|
accum!(x, Δ) = x .+ Δ
|
||||||
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
accum!(x::AbstractArray, Δ) = (x .+= Δ)
|
||||||
|
|
||||||
|
# This is correct, because the `.grad` field of tracked is constrained
|
||||||
|
# by a type parameter constructed from the (concrete) forward value and
|
||||||
|
# thus for `nothing` to show up here, we're guaranteed that the value
|
||||||
|
# was `nothing` during the forward pass (and thus we don't care about its
|
||||||
|
# derivatives).
|
||||||
|
accum!(x::Nothing, Δ) = x
|
||||||
|
|
||||||
function back(x::Tracked, Δ, once)
|
function back(x::Tracked, Δ, once)
|
||||||
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
x.isleaf && (x.grad = accum!(x.grad, Δ); return)
|
||||||
ref = x.ref -= 1
|
ref = x.ref -= 1
|
||||||
@ -126,7 +133,7 @@ accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] .+ Δ : Δ
|
|||||||
|
|
||||||
function back_(g::Grads, c::Call, Δ)
|
function back_(g::Grads, c::Call, Δ)
|
||||||
Δs = c.func(Δ)
|
Δs = c.func(Δ)
|
||||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
((Δs isa Tuple || Δs isa TrackedTuple) && length(Δs) >= length(c.args)) ||
|
||||||
error("Gradient is not a tuple of length $(length(c.args))")
|
error("Gradient is not a tuple of length $(length(c.args))")
|
||||||
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
|
||||||
end
|
end
|
||||||
|
@ -95,12 +95,27 @@ end
|
|||||||
|
|
||||||
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
||||||
|
|
||||||
@grad function getindex(xs::AbstractArray, i...)
|
struct ∇getindex{T,S}
|
||||||
data(xs)[i...], function (Δ)
|
xs::T
|
||||||
Δ′ = zero(xs)
|
i::S
|
||||||
Δ′[i...] = data(Δ)
|
end
|
||||||
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
|
|
||||||
|
function (g::∇getindex)(Δ)
|
||||||
|
Δ′ = zero(g.xs)
|
||||||
|
Δ′[g.i...] = Δ
|
||||||
|
(Δ′, map(_->nothing, g.i)...)
|
||||||
|
end
|
||||||
|
(g::∇getindex)(Δ::TrackedArray) = track(g, Δ)
|
||||||
|
|
||||||
|
@grad function (g::∇getindex)(Δ)
|
||||||
|
z, back = g(data(Δ)), function(Δ′′)
|
||||||
|
(Δ′′[1][g.i...],)
|
||||||
end
|
end
|
||||||
|
z, back
|
||||||
|
end
|
||||||
|
|
||||||
|
@grad function getindex(xs::AbstractArray, i...)
|
||||||
|
data(xs)[i...], ∇getindex(xs, i)
|
||||||
end
|
end
|
||||||
|
|
||||||
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
|
Base.view(x::TrackedArray, inds...) = track(Base.view, x, inds...)
|
||||||
@ -414,10 +429,15 @@ using ForwardDiff: Dual, partials, value
|
|||||||
|
|
||||||
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val(ndims(x))))
|
||||||
|
|
||||||
unbroadcast(x::AbstractArray, Δ) =
|
function unbroadcast(x::AbstractArray, Δ)
|
||||||
size(x) == size(Δ) ? Δ :
|
if size(x) == size(Δ)
|
||||||
length(x) == length(Δ) ? trim(x, Δ) :
|
return Δ
|
||||||
trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
elseif length(x) == length(Δ)
|
||||||
|
return trim(x, Δ)
|
||||||
|
else
|
||||||
|
return trim(x, sum(Δ, dims = ntuple(i -> size(x, i) == 1 ? i : ndims(Δ)+1, Val(ndims(Δ)))))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
unbroadcast(x::Number, Δ) = sum(Δ)
|
unbroadcast(x::Number, Δ) = sum(Δ)
|
||||||
unbroadcast(x::Base.RefValue, _) = nothing
|
unbroadcast(x::Base.RefValue, _) = nothing
|
||||||
|
@ -122,7 +122,9 @@ accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
|
|||||||
init_grad(x::Tuple) = init_grad.(x)
|
init_grad(x::Tuple) = init_grad.(x)
|
||||||
zero_grad!(x::Tuple) = zero_grad!.(x)
|
zero_grad!(x::Tuple) = zero_grad!.(x)
|
||||||
|
|
||||||
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
|
zero_or_nothing(x) = zero(x)
|
||||||
|
zero_or_nothing(x::Nothing) = nothing
|
||||||
|
track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero_or_nothing.(xs)))
|
||||||
|
|
||||||
function Base.show(io::IO, xs::TrackedTuple)
|
function Base.show(io::IO, xs::TrackedTuple)
|
||||||
show(io, data(xs))
|
show(io, data(xs))
|
||||||
@ -132,11 +134,18 @@ end
|
|||||||
Base.length(x::TrackedTuple) = length(data(x))
|
Base.length(x::TrackedTuple) = length(data(x))
|
||||||
|
|
||||||
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
|
||||||
|
Base.iterate(xs::TrackedTuple) = track(iterate, xs)
|
||||||
|
Base.iterate(xs::TrackedTuple, i::Integer) = track(iterate, xs, i)
|
||||||
|
Base.iterate(xs::TrackedTuple, i::TrackedReal) = iterate(xs, data(i))
|
||||||
|
|
||||||
@grad function getindex(xs::TrackedTuple, i)
|
@grad function getindex(xs::TrackedTuple, i)
|
||||||
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
|
data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@grad function iterate(xs::TrackedTuple, i=1)
|
||||||
|
(data(xs)[i], i+1), Δ -> (ntuple(j -> i == j ? Δ[1] : 0, length(xs)), nothing)
|
||||||
|
end
|
||||||
|
|
||||||
# Array collection
|
# Array collection
|
||||||
|
|
||||||
function collect(xs)
|
function collect(xs)
|
||||||
|
@ -5,6 +5,7 @@ using NNlib: conv, depthwiseconv
|
|||||||
using Printf: @sprintf
|
using Printf: @sprintf
|
||||||
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
using LinearAlgebra: diagm, dot, LowerTriangular, norm
|
||||||
using Statistics: mean, std
|
using Statistics: mean, std
|
||||||
|
using ForwardDiff
|
||||||
using Random
|
using Random
|
||||||
# using StatsBase
|
# using StatsBase
|
||||||
|
|
||||||
@ -36,6 +37,25 @@ gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...)
|
|||||||
|
|
||||||
@testset "indexing & slicing" begin
|
@testset "indexing & slicing" begin
|
||||||
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
|
gradtest(x->view(x, 1:2, 1:2), rand(4, 4))
|
||||||
|
|
||||||
|
# Nested AD for getindex
|
||||||
|
grad_tracker = gradient([1.0, 2.0, 3.0]; nest=true) do x
|
||||||
|
sum(gradient(x; nest=true) do x
|
||||||
|
sum(gradient(x; nest=true) do x
|
||||||
|
sum(x[1:2])^4
|
||||||
|
end[1])
|
||||||
|
end[1])
|
||||||
|
end[1]
|
||||||
|
# We compare to ForwardDiff, since the high order derivative is not
|
||||||
|
# numerically stable under finite differencing.
|
||||||
|
grad_forward = ForwardDiff.gradient([1.0, 2.0, 3.0]) do x
|
||||||
|
sum(ForwardDiff.gradient(x) do x
|
||||||
|
sum(ForwardDiff.gradient(x) do x
|
||||||
|
sum(x[1:2])^4
|
||||||
|
end)
|
||||||
|
end)
|
||||||
|
end
|
||||||
|
@test grad_tracker ≈ grad_forward ≈ [288.0, 288.0, 0.0]
|
||||||
end
|
end
|
||||||
|
|
||||||
function promotiontest(f, A, B, C)
|
function promotiontest(f, A, B, C)
|
||||||
|
Loading…
Reference in New Issue
Block a user