From 70b5efeb4e83a1c9c317271e16181b8128b8ecfd Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Tue, 10 Jul 2018 09:03:09 +0100 Subject: [PATCH] basic nested AD --- src/tracker/Tracker.jl | 10 +++- src/tracker/array.jl | 113 ++++++++++++++++++++++------------------- src/tracker/back.jl | 10 ++-- src/tracker/numeric.jl | 2 +- src/tracker/scalar.jl | 8 +-- 5 files changed, 80 insertions(+), 63 deletions(-) diff --git a/src/tracker/Tracker.jl b/src/tracker/Tracker.jl index 5afe9ced..2a94edb7 100644 --- a/src/tracker/Tracker.jl +++ b/src/tracker/Tracker.jl @@ -47,7 +47,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f) function _forward end function track(f, xs...; kw...) - y, back = _forward(f, data.(xs)...; kw...) + y, back = _forward(f, xs...; kw...) track(Call(back, tracker.(xs)), y) end @@ -97,9 +97,17 @@ checkpoint(f, args...) = track(checkpoint, f, args...) end end +nobacksies(f, x) = track(nobacksies, f, x) +nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs) +@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f") + param(x::Number) = TrackedReal(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) +@grad identity(x) = data(x), Δ -> (Δ,) +param(x::TrackedReal) = track(identity, x) +param(x::TrackedArray) = track(identity, x) + import NNlib.cudata import Adapt.adapt diff --git a/src/tracker/array.jl b/src/tracker/array.jl index ee8c9b39..e034a868 100644 --- a/src/tracker/array.jl +++ b/src/tracker/array.jl @@ -50,6 +50,9 @@ for f in :[Base.size, Base.ndims].args @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) end +Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) = + size(data(x), i, j, is...) + Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = similar(data(x), dims...) @@ -65,54 +68,54 @@ Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) @grad function getindex(xs, i...) data(xs)[i...], function (Δ) - Δ′ = zeros(xs) - Δ′[i...] = Δ - (Δ′, map(_->nothing, i)...) + Δ′ = zero(xs) + Δ′[i...] = data(Δ) + (nobacksies(:getindex, Δ′), map(_->nothing, i)...) end end Base.:-(xs::TrackedArray) = track(-, xs) -@grad -(xs) = -xs, Δ -> (-Δ,) +@grad -(xs) = -data(xs), Δ -> (-Δ,) Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) -@grad transpose(xs) = xs.', Δ -> (reshape(Δ.', size(xs)),) -@grad ctranspose(xs) = xs', Δ -> (reshape(Δ', size(xs)),) +@grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),) +@grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) @grad function repmat(xs, m, n = 1) - repmat(xs, m, n), function (Δ) + repmat(data(xs), m, n), function (Δ) Δ′ = similar(xs) S = size(xs) - for (i,v) in enumerate(Δ) + for (i,v) in enumerate(data(Δ)) d1 = divrem(i-1, S[1]*m) x = d1[2] % S[1]+1 y = d1[1] % S[2]+1 Δ′[x, y] += v end - return (Δ′, nothing, nothing) + return (nobacksies(:repmat, Δ′), nothing, nothing) end end Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...) @grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) - repeat(xs, inner = inner, outer = outer), function (Δ) + repeat(data(xs), inner = inner, outer = outer), function (Δ) Δ′ = zero(xs) S = size(xs) # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ - for (dest_idx, val) in enumerate(IndexCartesian(), Δ) + for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ)) # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then # wrap around based on original size S. src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] Δ′[src_idx...] += val end - (Δ′,) + (nobacksies(:repeat, Δ′),) end end @@ -143,7 +146,7 @@ for f in [:vcat, :hcat] end @grad function vcat(xs...) - vcat(xs...), function (Δ) + vcat(data.(xs)...), function (Δ) start = 0 Δs = [begin i = map(_ -> :, size(xsi)) |> Base.tail @@ -156,7 +159,7 @@ end end @grad function hcat(xs...) - hcat(xs...), function (Δ) + hcat(data.(xs)...), function (Δ) start = 0 Δs = [begin d = if ndims(xsi) == 1 @@ -176,7 +179,7 @@ Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) @grad function cat(dims, Xs...) - cat(dims, Xs...), function (Δ) + cat(dims, data.(Xs)...), function (Δ) start = ntuple(i -> 0, Val{ndims(Δ)}) Δs = [begin dim_xs = 1:ndims(xs) @@ -194,10 +197,10 @@ Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) -@grad reshape(xs, dims) = reshape(xs, dims), Δ -> (reshape(Δ, size(xs)),nothing) +@grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing) Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims) -@grad permutedims(xs, dims) = permutedims(xs, dims), Δ -> (permutedims(Δ, invperm(dims)),nothing) +@grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing) function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) m1, n1 = size(mat1) @@ -219,16 +222,18 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim) Base.sum(xs::TrackedArray) = track(sum, xs) Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) -@grad sum(xs, dim...) = sum(xs, dim...), - Δ -> (similar(xs) .= Δ, map(_->nothing,dim)...) +@grad sum(xs, dim...) = sum(data(xs), dim...), + Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...) Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) Base.prod(xs::TrackedArray) = track(prod, xs) Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) -@grad prod(xs) = prod(xs), Δ -> (similar(xs) .= (prod(xs) ./ xs) .* Δ,) -@grad prod(xs, dim) = prod(xs, dim), - Δ -> (similar(xs) .= (reshape(.*(circshift.([reshape(xs, length(xs))], 1:length(xs)-1)...), size(xs))) .* Δ,nothing) +@grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,) +@grad prod(xs, dim) = prod(data(xs), dim), + Δ -> (nobacksies(:sum, + reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ), + nothing) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) @@ -244,7 +249,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) -@grad dot(xs, ys) = dot(xs, ys), Δ -> (Δ .* ys, Δ .* xs) +@grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs) # Hacks to get std working Base.std(x::TrackedArray; mean = Base.mean(x)) = @@ -255,32 +260,32 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) = Base.vecnorm(x::TrackedArray, p::Real = 2) = sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 -@grad mean(xs) = mean(xs), Δ -> (similar(xs) .= Δ ./ length(xs),) -@grad mean(xs, region) = mean(xs, region), Δ -> (similar(xs) .= Δ ./ prod(size(xs, region...)),nothing) +@grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),) +@grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing) @grad function maximum(xs, r...) - maximum(xs, r...), function (Δ) - Δ′ = zeros(xs) - _, i = findmax(xs, r...) - Δ′[i] = Δ - return (Δ′,map(_->nothing,r)...) + maximum(data(xs), r...), function (Δ) + Δ′ = zero(xs) + _, i = findmax(data(xs), r...) + Δ′[i] = data(Δ) + return (nobacksies(:maximum, Δ′),map(_->nothing,r)...) end end @grad function minimum(xs, r...) - minimum(xs, r...), function (Δ) - Δ′ = zeros(xs) - _, i = findmin(xs, r...) - Δ′[i] = Δ - return (Δ′,map(_->nothing,r)...) + minimum(data(xs), r...), function (Δ) + Δ′ = zero(xs) + _, i = findmin(data(xs), r...) + Δ′[i] = data(Δ) + return (nobacksies(:minimum, Δ′),map(_->nothing,r)...) end end # BLAS Base.diagm(x::TrackedVector) = track(diagm, x) -@grad diagm(x) = diagm(x), Δ -> (diag(Δ),) +@grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),) -for f in :[*, Ac_mul_B, A_mul_Bc].args +for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args @eval begin import Base.$f $f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b) @@ -298,10 +303,13 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args end @grad a::AbstractMatrix * b::AbstractVecOrMat = - a*b, Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ)) + data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ)) -@grad Ac_mul_B(a, b) = Ac_mul_B(a, b), Δ -> (A_mul_Bt(Δ, b)', a*Δ) -@grad A_mul_Bc(a, b) = A_mul_Bc(a, b), Δ -> (Δ * b, At_mul_B(a, Δ)') +@grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ) +@grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)') + +@grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ) +@grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)') # NNlib @@ -310,33 +318,34 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, mea softmax(xs::TrackedArray) = track(softmax, xs) -@grad softmax(xs) = softmax(xs), Δ -> (∇softmax(Δ, xs),) +@grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),) logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) -@grad logsoftmax(xs) = logsoftmax(xs), Δ -> (∇logsoftmax(Δ, xs),) +@grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),) conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) @grad conv(x, w; kw...) = - conv(x, w; kw...), - Δ -> (NNlib.∇conv_data(Δ, x, w; kw...), - NNlib.∇conv_filter(Δ, x, w; kw...)) + conv(data(x), data(w); kw...), + Δ -> nobacksies(:conv, + (NNlib.∇conv_data(data.((Δ, x, w))...; kw...), + NNlib.∇conv_filter(data.((Δ, x, w))...; kw...))) maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...) @grad function maxpool(x, k; kw...) - y = maxpool(x, k; kw...) - y, Δ -> (NNlib.∇maxpool(Δ, y, x, k; kw...), nothing) + y = maxpool(data(x), k; kw...) + y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing) end meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...) @grad function meanpool(x, k; kw...) - y = meanpool(x, k; kw...) - y, Δ -> (NNlib.∇meanpool(Δ, y, x, k; kw...), nothing) + y = meanpool(data(x), k; kw...) + y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing) end # Broadcasting @@ -364,9 +373,11 @@ function ∇broadcast(f, args::Vararg{Any,N}) where N out = broadcast(f, dargs...) eltype(out) <: Dual || return out y = value.(out) - back = function (Δ) + back = function (Δ_) + Δ = data(Δ_) Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N}) - map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs) + dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs) + nobacksies(:broadcast, dxs) end # So we can return non-tracked arrays track(Call(back, tracker.(args)), y) diff --git a/src/tracker/back.jl b/src/tracker/back.jl index c6d1646a..8e492861 100644 --- a/src/tracker/back.jl +++ b/src/tracker/back.jl @@ -23,7 +23,7 @@ function back_(c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") - foreach(back, c.args, Δs) + foreach(back, c.args, data.(Δs)) end back_(::Call{Void}, Δ) = nothing @@ -78,6 +78,8 @@ end Grads() = Grads(ObjectIdDict()) +Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps)) + Base.getindex(g::Grads, x::Tracked) = g.grads[x] function Base.getindex(g::Grads, x) istracked(x) || error("Object not tracked: $x") @@ -114,15 +116,11 @@ back(::Grads, ::Void, _) = return function forward(f, ps::Params) y = f() y, function (Δ) - g = Grads() + g = Grads(ps) if istracked(y) scan(y) back(g, tracker(y), Δ) end - for p in ps - haskey(g, tracker(p)) || - (g[tracker(p)] = init_grad(data(p))) - end return g end end diff --git a/src/tracker/numeric.jl b/src/tracker/numeric.jl index e0028b7c..1ad872e4 100644 --- a/src/tracker/numeric.jl +++ b/src/tracker/numeric.jl @@ -15,4 +15,4 @@ end gradcheck(f, xs...) = all(isapprox.(ngradient(f, xs...), - gradient(f, xs...), rtol = 1e-5, atol = 1e-5)) + data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5)) diff --git a/src/tracker/scalar.jl b/src/tracker/scalar.jl index 6232807f..7e574fd9 100644 --- a/src/tracker/scalar.jl +++ b/src/tracker/scalar.jl @@ -50,17 +50,17 @@ for (M, f, arity) in DiffRules.diffrules() arity == 1 || continue @eval begin @grad $M.$f(a::Real) = - $M.$f(a), Δ -> (Δ * $(DiffRules.diffrule(M, f, :(data(a)))),) + $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),) $M.$f(a::TrackedReal) = track($M.$f, a) end end for (M, f, arity) in DiffRules.diffrules() arity == 2 || continue - da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) + da, db = DiffRules.diffrule(M, f, :a, :b) f = :($M.$f) @eval begin - @grad $f(a::Real, b::Real) = $f(a, b), Δ -> (Δ * $da, Δ * $db) + @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b) $f(a::Real, b::TrackedReal) = track($f, a, b) @@ -111,5 +111,5 @@ function scan(c::Call{typeof(collect)}) end function back_(c::Call{typeof(collect)}, Δ) - foreach(back, c.args[1], Δ) + foreach(back, c.args[1], data(Δ)) end