basic nested AD
This commit is contained in:
parent
80af9a3830
commit
70b5efeb4e
@ -47,7 +47,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
|
||||
function _forward end
|
||||
|
||||
# Run `f` through `_forward` to obtain its result together with a backward
# closure, then record that closure and the trackers of `xs` in a `Call`
# that is attached to the result.
function track(f, xs...; kw...)
  result, pullback = _forward(f, xs...; kw...)
  call = Call(pullback, tracker.(xs))
  return track(call, result)
end
|
||||
|
||||
@ -97,9 +97,17 @@ checkpoint(f, args...) = track(checkpoint, f, args...)
|
||||
end
|
||||
end
|
||||
|
||||
# `nobacksies(f, x)` records `x` through `track`. Its gradient rule (below)
# forwards the raw data of `x` and raises an error naming `f` as soon as a
# gradient is requested through it.
function nobacksies(f, x)
  return track(nobacksies, f, x)
end

# Tuples are handled one element at a time.
function nobacksies(f, xs::Tuple)
  return map(xs) do x
    nobacksies(f, x)
  end
end

@grad function nobacksies(f, x)
  return data(x), Δ -> error("Nested AD not defined for $f")
end
|
||||
|
||||
# Plain numbers and arrays are converted with `float` and wrapped in the
# corresponding tracked type.
function param(x::Number)
  return TrackedReal(float(x))
end

function param(xs::AbstractArray)
  return TrackedArray(float.(xs))
end

# `identity` forwards the raw data and hands the incoming gradient back as-is.
@grad function identity(x)
  return data(x), Δ -> (Δ,)
end

# Values that are already tracked go through a tracked `identity` call.
for T in (:TrackedReal, :TrackedArray)
  @eval param(x::$T) = track(identity, x)
end
|
||||
|
||||
import NNlib.cudata
|
||||
import Adapt.adapt
|
||||
|
||||
|
@ -50,6 +50,9 @@ for f in :[Base.size, Base.ndims].args
|
||||
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
|
||||
end
|
||||
|
||||
# `size` with two or more dimension indices is answered by the wrapped data.
function Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...)
  return size(data(x), i, j, is...)
end
|
||||
|
||||
# `similar` with explicit dimensions is delegated to the wrapped data, so the
# result is whatever `similar` returns for that underlying array.
function Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...)
  return similar(data(x), dims...)
end
|
||||
|
||||
@ -65,54 +68,54 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
|
||||
|
||||
# Gradient rule for indexing: the forward value indexes the raw data; the
# backward pass writes the raw Δ into a zero array at the same indices and
# returns `nothing` for every index argument.
@grad function getindex(xs, i...)
  value = data(xs)[i...]
  function pullback(Δ)
    Δ′ = zero(xs)
    Δ′[i...] = data(Δ)
    return (nobacksies(:getindex, Δ′), map(_ -> nothing, i)...)
  end
  return value, pullback
end
|
||||
|
||||
# Unary minus on a tracked array is recorded through `track`.
function Base.:-(xs::TrackedArray)
  return track(-, xs)
end

# Forward: negate the raw data. Backward: negate the incoming gradient.
@grad function -(xs)
  return -data(xs), Δ -> (-Δ,)
end
|
||||
|
||||
# Both transposes of a tracked array are recorded through `track`.
function Base.transpose(xs::TrackedArray)
  return track(transpose, xs)
end

function Base.ctranspose(xs::TrackedArray)
  return track(ctranspose, xs)
end

# Forward on the raw data; the gradient is transposed back and reshaped to
# the size of the input.
@grad function transpose(xs)
  return data(xs).', Δ -> (reshape(Δ.', size(xs)),)
end

@grad function ctranspose(xs)
  return data(xs)', Δ -> (reshape(Δ', size(xs)),)
end
|
||||
|
||||
# `repmat` on a tracked vector or matrix is recorded through `track`.
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)

# Gradient rule for `repmat(xs, m, n)`.
# Forward: tile the raw data. Backward: every element of Δ is added onto the
# source element it was copied from; `m` and `n` receive no gradient.
@grad function repmat(xs, m, n = 1)
  repmat(data(xs), m, n), function (Δ)
    # Must start from zeros: the loop accumulates with `+=`, so an
    # uninitialised `similar(xs)` buffer would leak garbage into the gradient.
    Δ′ = zero(xs)
    # `size(xs, d)` instead of `size(xs)[d]` so that vector inputs
    # (which have no second entry in `size(xs)`) are handled as one column.
    rows, cols = size(xs, 1), size(xs, 2)
    for (i, v) in enumerate(data(Δ))
      # Δ is walked in linear order; split the position into a zero-based
      # column and row of Δ, then wrap both back onto the source.
      col, row = divrem(i - 1, rows * m)
      Δ′[row % rows + 1, col % cols + 1] += v
    end
    return (nobacksies(:repmat, Δ′), nothing, nothing)
  end
end
|
||||
|
||||
# `repeat` on a tracked array is recorded through `track`, forwarding keywords.
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)

# Gradient rule for `repeat(xs; inner, outer)`.
# The keyword defaults are derived from `xs`, the argument of this definition
# (the previous defaults referred to `A`, which is not in scope here).
# Forward: repeat the raw data. Backward: accumulate each element of Δ onto
# the source element it was copied from.
@grad function repeat(xs; inner=ntuple(x->1, ndims(xs)), outer=ntuple(x->1, ndims(xs)))
  repeat(data(xs), inner = inner, outer = outer), function (Δ)
    Δ′ = zero(xs)
    S = size(xs)

    # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
    for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
      # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
      # wrap around based on original size S.
      src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
      Δ′[src_idx...] += val
    end
    (nobacksies(:repeat, Δ′),)
  end
end
|
||||
|
||||
@ -143,7 +146,7 @@ for f in [:vcat, :hcat]
|
||||
end
|
||||
|
||||
@grad function vcat(xs...)
|
||||
vcat(xs...), function (Δ)
|
||||
vcat(data.(xs)...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
i = map(_ -> :, size(xsi)) |> Base.tail
|
||||
@ -156,7 +159,7 @@ end
|
||||
end
|
||||
|
||||
@grad function hcat(xs...)
|
||||
hcat(xs...), function (Δ)
|
||||
hcat(data.(xs)...), function (Δ)
|
||||
start = 0
|
||||
Δs = [begin
|
||||
d = if ndims(xsi) == 1
|
||||
@ -176,7 +179,7 @@ Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
|
||||
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
|
||||
|
||||
@grad function cat(dims, Xs...)
|
||||
cat(dims, Xs...), function (Δ)
|
||||
cat(dims, data.(Xs)...), function (Δ)
|
||||
start = ntuple(i -> 0, Val{ndims(Δ)})
|
||||
Δs = [begin
|
||||
dim_xs = 1:ndims(xs)
|
||||
@ -194,10 +197,10 @@ Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
|
||||
# Dimensions that may contain a `Colon` are resolved first, then dispatch
# continues to the all-`Int64` method.
function Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}})
  resolved = Base._reshape_uncolon(xs, dims)
  return reshape(xs, resolved)
end

# Fully specified dimensions are recorded through `track`.
function Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}})
  return track(reshape, xs, dims)
end

# Forward on the raw data; the gradient is reshaped back to `size(xs)` and
# `dims` receives `nothing`.
@grad function reshape(xs, dims)
  return reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
end
|
||||
|
||||
# `permutedims` on a tracked array is recorded through `track`.
function Base.permutedims(xs::TrackedArray, dims)
  return track(permutedims, xs, dims)
end

# Forward on the raw data; the gradient is permuted back with `invperm(dims)`
# and `dims` receives `nothing`.
@grad function permutedims(xs, dims)
  return permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
end
|
||||
|
||||
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
|
||||
m1, n1 = size(mat1)
|
||||
@ -219,16 +222,18 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
|
||||
# Whole-array `sum` is recorded through `track`.
function Base.sum(xs::TrackedArray)
  return track(sum, xs)
end

# `sum(f, xs)` is expressed as a broadcast of `f` followed by a plain `sum`.
function Base.sum(f::Union{Function,Type},xs::TrackedArray)
  return sum(f.(xs))
end

# Forward on the raw data; the gradient broadcasts Δ onto a zero array shaped
# like `xs`, and each `dim` argument receives `nothing`.
@grad function sum(xs, dim...)
  total = sum(data(xs), dim...)
  return total, Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
end
|
||||
|
||||
# `prod` on a tracked array (optionally along `dim`) is recorded through `track`.
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs)
# `prod(f, xs)` is expressed as a broadcast of `f` followed by a plain `prod`.
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))

# Whole-array product: forward on the raw data; gradient is prod(xs) ./ xs .* Δ.
@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)

# Product along `dim`: forward on the raw data. The gradient is wrapped in
# `nobacksies` under the label `:prod` (it was previously labelled `:sum`, so
# the nested-AD error named the wrong operation); `dim` receives `nothing`.
# NOTE(review): the circshift product multiplies together all other
# `length(xs) - 1` elements of `xs`, not only those sharing a slice along
# `dim` — confirm this is the intended derivative for the dim-wise product.
@grad prod(xs, dim) = prod(data(xs), dim),
  Δ -> (nobacksies(:prod,
          reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
        nothing)
|
||||
|
||||
# `findfirst` is evaluated directly on the `data` field; nothing is tracked.
function Base.findfirst(xs::TrackedArray, args...)
  return findfirst(xs.data, args...)
end
|
||||
|
||||
@ -244,7 +249,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
|
||||
# Mixed tracked / untracked `dot` products are recorded through `track`.
function LinAlg.dot(xs::AbstractVector, ys::TrackedVector)
  return track(dot, xs, ys)
end

function LinAlg.dot(xs::TrackedVector, ys::AbstractVector)
  return track(dot, xs, ys)
end

# Forward on the raw data; each argument's gradient is Δ times the other argument.
@grad function dot(xs, ys)
  return dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
end
|
||||
|
||||
# Hacks to get std working
|
||||
Base.std(x::TrackedArray; mean = Base.mean(x)) =
|
||||
@ -255,32 +260,32 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
|
||||
# p-norm built from tracked primitives. `eps(0f0)` is added to every term
# to avoid d(sqrt(x))/dx == Inf at 0.
function Base.vecnorm(x::TrackedArray, p::Real = 2)
  total = sum(abs.(x).^p .+ eps(0f0))
  return total^(1/p)
end
|
||||
|
||||
# Whole-array mean: forward on the raw data. The gradient is broadcast onto a
# zero array so it has the shape of `xs`, matching the `region` method below
# (it previously returned the bare scalar Δ / length(xs)).
@grad mean(xs) = mean(data(xs)), Δ -> (zero(xs) .+ Δ ./ length(xs),)

# Mean over `region`: forward on the raw data; Δ is divided by the number of
# elements averaged per output entry and broadcast to the shape of `xs`.
# `region` receives `nothing`.
@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
|
||||
|
||||
# Gradient rule for `maximum`: forward on the raw data; the backward pass
# writes the raw Δ into a zero array at the position(s) reported by `findmax`,
# and every extra argument in `r` receives `nothing`.
@grad function maximum(xs, r...)
  peak = maximum(data(xs), r...)
  function pullback(Δ)
    Δ′ = zero(xs)
    _, i = findmax(data(xs), r...)
    Δ′[i] = data(Δ)
    return (nobacksies(:maximum, Δ′), map(_->nothing,r)...)
  end
  return peak, pullback
end
|
||||
# Gradient rule for `minimum`: forward on the raw data; the backward pass
# writes the raw Δ into a zero array at the position(s) reported by `findmin`,
# and every extra argument in `r` receives `nothing`.
@grad function minimum(xs, r...)
  lowest = minimum(data(xs), r...)
  function pullback(Δ)
    Δ′ = zero(xs)
    _, i = findmin(data(xs), r...)
    Δ′[i] = data(Δ)
    return (nobacksies(:minimum, Δ′), map(_->nothing,r)...)
  end
  return lowest, pullback
end
|
||||
|
||||
# BLAS
|
||||
|
||||
# `diagm` on a tracked vector is recorded through `track`.
function Base.diagm(x::TrackedVector)
  return track(diagm, x)
end

# Forward on the raw data; the gradient is the diagonal of Δ.
@grad function diagm(x)
  return diagm(data(x)), Δ -> (diag(Δ),)
end
|
||||
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc].args
|
||||
for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
|
||||
@eval begin
|
||||
import Base.$f
|
||||
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
|
||||
@ -298,10 +303,13 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args
|
||||
end
|
||||
|
||||
# Gradient rule for matrix * vector-or-matrix: forward on the raw data;
# the gradients are Δ * bᵀ for `a` and aᵀ * Δ for `b`.
@grad function *(a::AbstractMatrix, b::AbstractVecOrMat)
  return data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
end
|
||||
|
||||
# Gradient rules for the (conjugate-)transposed products. Each one evaluates
# the product on the raw data and returns one gradient per argument.
@grad function Ac_mul_B(a, b)
  return Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
end

@grad function A_mul_Bc(a, b)
  return A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
end

@grad function At_mul_B(a, b)
  return At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
end

@grad function A_mul_Bt(a, b)
  return A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
end
|
||||
|
||||
# NNlib
|
||||
|
||||
@ -310,33 +318,34 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, mea
|
||||
|
||||
# `softmax` on a tracked array is recorded through `track`.
function softmax(xs::TrackedArray)
  return track(softmax, xs)
end

# Forward on the raw data; the gradient calls `∇softmax` on raw Δ and raw xs
# and is wrapped in `nobacksies`.
@grad function softmax(xs)
  return softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
end
|
||||
|
||||
# `logsoftmax` on a tracked array is recorded through `track`.
function logsoftmax(xs::TrackedArray)
  return track(logsoftmax, xs)
end

# Forward on the raw data; the gradient calls `∇logsoftmax` on raw Δ and raw
# xs and is wrapped in `nobacksies`.
@grad function logsoftmax(xs)
  return logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
end
|
||||
|
||||
# `conv` is recorded through `track` whenever the input, the filter, or both
# are tracked.
for (X, W) in ((:TrackedArray, :TrackedArray),
               (:AbstractArray, :TrackedArray),
               (:TrackedArray, :AbstractArray))
  @eval conv(x::$X, w::$W; kw...) = track(conv, x, w; kw...)
end

# Forward on the raw data. The backward pass strips tracking from Δ, x and w,
# computes the input and filter gradients with NNlib, and wraps the pair in
# `nobacksies`.
@grad function conv(x, w; kw...)
  y = conv(data(x), data(w); kw...)
  function pullback(Δ)
    raw = data.((Δ, x, w))
    Δx = NNlib.∇conv_data(raw...; kw...)
    Δw = NNlib.∇conv_filter(raw...; kw...)
    return nobacksies(:conv, (Δx, Δw))
  end
  return y, pullback
end
|
||||
|
||||
# `maxpool` on a tracked array is recorded through `track`.
function maxpool(x::TrackedArray, k; kw...)
  return track(maxpool, x, k; kw...)
end

# Forward on the raw data. The backward pass calls `NNlib.∇maxpool` on the
# raw Δ, the pooled output and the raw input; `k` receives `nothing`.
@grad function maxpool(x, k; kw...)
  pooled = maxpool(data(x), k; kw...)
  function pullback(Δ)
    Δx = NNlib.∇maxpool(data.((Δ, pooled, x))..., k; kw...)
    return (nobacksies(:maxpool, Δx), nothing)
  end
  return pooled, pullback
end
|
||||
|
||||
# `meanpool` on a tracked array is recorded through `track`.
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)

# Gradient rule for `meanpool`.
# Forward: pool the raw data. Backward: call `NNlib.∇meanpool` on the raw Δ,
# the pooled output and the raw input; `k` receives `nothing`.
# The `nobacksies` label is `:meanpool` (it was `:maxpool`, a copy-paste slip
# that made the nested-AD error message name the wrong operation).
@grad function meanpool(x, k; kw...)
  y = meanpool(data(x), k; kw...)
  y, Δ -> (nobacksies(:meanpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
end
|
||||
|
||||
# Broadcasting
|
||||
@ -364,9 +373,11 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N
|
||||
out = broadcast(f, dargs...)
|
||||
eltype(out) <: Dual || return out
|
||||
y = value.(out)
|
||||
back = function (Δ)
|
||||
back = function (Δ_)
|
||||
Δ = data(Δ_)
|
||||
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
|
||||
map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
||||
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
|
||||
nobacksies(:broadcast, dxs)
|
||||
end
|
||||
# So we can return non-tracked arrays
|
||||
track(Call(back, tracker.(args)), y)
|
||||
|
@ -23,7 +23,7 @@ function back_(c::Call, Δ)
|
||||
Δs = c.func(Δ)
|
||||
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
|
||||
error("Gradient is not a tuple of length $(length(c.args))")
|
||||
foreach(back, c.args, Δs)
|
||||
foreach(back, c.args, data.(Δs))
|
||||
end
|
||||
|
||||
# A `Call{Void}` has nothing to propagate into, so this is a no-op.
function back_(::Call{Void}, Δ)
  return nothing
end
|
||||
@ -78,6 +78,8 @@ end
|
||||
|
||||
# An empty gradient table.
function Grads()
  return Grads(ObjectIdDict())
end

# A gradient table with one entry per parameter, keyed by the parameter's
# tracker and initialised with `init_grad` of its data.
function Grads(ps::Params)
  table = ObjectIdDict()
  for p in ps
    table[tracker(p)] = init_grad(data(p))
  end
  return Grads(table)
end

# Look up the gradient stored for a tracker.
function Base.getindex(g::Grads, x::Tracked)
  return g.grads[x]
end
|
||||
function Base.getindex(g::Grads, x)
|
||||
istracked(x) || error("Object not tracked: $x")
|
||||
@ -114,15 +116,11 @@ back(::Grads, ::Void, _) = return
|
||||
# Evaluate `f()` and return its value together with a closure that, given Δ,
# builds a `Grads` table for `ps` and back-propagates Δ from the result when
# that result is tracked.
function forward(f, ps::Params)
  y = f()
  function pullback(Δ)
    g = Grads(ps)
    # An untracked result has no graph to walk; return the initial table.
    istracked(y) || return g
    scan(y)
    back(g, tracker(y), Δ)
    return g
  end
  return y, pullback
end
|
||||
|
@ -15,4 +15,4 @@ end
|
||||
|
||||
# Compare the result of `ngradient` with the (untracked) result of `gradient`
# element-wise, within rtol = atol = 1e-5.
function gradcheck(f, xs...)
  expected = ngradient(f, xs...)
  actual = data.(gradient(f, xs...))
  return all(isapprox.(expected, actual, rtol = 1e-5, atol = 1e-5))
end
|
||||
|
@ -50,17 +50,17 @@ for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 1 || continue
|
||||
@eval begin
|
||||
@grad $M.$f(a::Real) =
|
||||
$M.$f(a), Δ -> (Δ * $(DiffRules.diffrule(M, f, :(data(a)))),)
|
||||
$M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
|
||||
$M.$f(a::TrackedReal) = track($M.$f, a)
|
||||
end
|
||||
end
|
||||
|
||||
for (M, f, arity) in DiffRules.diffrules()
|
||||
arity == 2 || continue
|
||||
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b)))
|
||||
da, db = DiffRules.diffrule(M, f, :a, :b)
|
||||
f = :($M.$f)
|
||||
@eval begin
|
||||
@grad $f(a::Real, b::Real) = $f(a, b), Δ -> (Δ * $da, Δ * $db)
|
||||
@grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
|
||||
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
|
||||
$f(a::TrackedReal, b::Real) = track($f, a, b)
|
||||
$f(a::Real, b::TrackedReal) = track($f, a, b)
|
||||
@ -111,5 +111,5 @@ function scan(c::Call{typeof(collect)})
|
||||
end
|
||||
|
||||
# Backward pass for `collect`: pair every element of the first call argument
# with the matching element of the raw Δ and call `back` on each pair.
function back_(c::Call{typeof(collect)}, Δ)
  for (x, δ) in zip(c.args[1], data(Δ))
    back(x, δ)
  end
  return nothing
end
|
||||
|
Loading…
Reference in New Issue
Block a user