basic nested AD

This commit is contained in:
Mike J Innes 2018-07-09 17:11:11 +01:00
parent e763c342ee
commit 7b03c8ac22
4 changed files with 36 additions and 34 deletions

View File

@ -47,7 +47,7 @@ track(f::Call, x) = Tracked{typeof(x)}(f)
function _forward end function _forward end
function track(f, xs...; kw...) function track(f, xs...; kw...)
y, back = _forward(f, data.(xs)...; kw...) y, back = _forward(f, xs...; kw...)
track(Call(back, tracker.(xs)), y) track(Call(back, tracker.(xs)), y)
end end
@ -100,6 +100,10 @@ end
param(x::Number) = TrackedReal(float(x)) param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs)) param(xs::AbstractArray) = TrackedArray(float.(xs))
@grad identity(x) = data(x), Δ -> (Δ,)
param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)
import NNlib.cudata import NNlib.cudata
import Adapt.adapt import Adapt.adapt

View File

@ -73,19 +73,19 @@ end
Base.:-(xs::TrackedArray) = track(-, xs) Base.:-(xs::TrackedArray) = track(-, xs)
@grad -(xs) = -xs, Δ -> (-Δ,) @grad -(xs) = -data(xs), Δ -> (-Δ,)
Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
@grad transpose(xs) = xs.', Δ -> (trim(xs, Δ.'),) @grad transpose(xs) = data(xs).', Δ -> (trim(xs, Δ.'),)
@grad ctranspose(xs) = xs', Δ -> (trim(xs, Δ'),) @grad ctranspose(xs) = data(xs)', Δ -> (trim(xs, Δ'),)
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
@grad function repmat(xs, m, n = 1) @grad function repmat(xs, m, n = 1)
repmat(xs, m, n), function (Δ) repmat(data(xs), m, n), function (Δ)
Δ′ = similar(xs) Δ′ = similar(xs)
S = size(xs) S = size(xs)
for (i,v) in enumerate(Δ) for (i,v) in enumerate(Δ)
@ -101,7 +101,7 @@ end
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...) Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
@grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) @grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
repeat(xs, inner = inner, outer = outer), function (Δ) repeat(data(xs), inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs) Δ′ = zero(xs)
S = size(xs) S = size(xs)
@ -143,7 +143,7 @@ for f in [:vcat, :hcat]
end end
@grad function vcat(xs...) @grad function vcat(xs...)
vcat(xs...), function (Δ) vcat(data.(xs)...), function (Δ)
start = 0 start = 0
Δs = [begin Δs = [begin
i = map(_ -> :, size(xsi)) |> Base.tail i = map(_ -> :, size(xsi)) |> Base.tail
@ -156,7 +156,7 @@ end
end end
@grad function hcat(xs...) @grad function hcat(xs...)
hcat(xs...), function (Δ) hcat(data.(xs)...), function (Δ)
start = 0 start = 0
Δs = [begin Δs = [begin
d = if ndims(xsi) == 1 d = if ndims(xsi) == 1
@ -176,7 +176,7 @@ Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
@grad function cat(dims, Xs...) @grad function cat(dims, Xs...)
cat(dims, Xs...), function (Δ) cat(dims, data.(Xs)...), function (Δ)
start = ntuple(i -> 0, Val{ndims(Δ)}) start = ntuple(i -> 0, Val{ndims(Δ)})
Δs = [begin Δs = [begin
dim_xs = 1:ndims(xs) dim_xs = 1:ndims(xs)
@ -219,15 +219,15 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
Base.sum(xs::TrackedArray) = track(sum, xs) Base.sum(xs::TrackedArray) = track(sum, xs)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
@grad sum(xs, dim...) = sum(xs, dim...), @grad sum(xs, dim...) = sum(data(xs), dim...),
Δ -> (similar(xs) .= Δ, map(_->nothing,dim)...) Δ -> (similar(xs) .= Δ, map(_->nothing,dim)...)
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs) Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
@grad prod(xs) = prod(xs), Δ -> (similar(xs) .= (prod(xs) ./ xs) .* Δ,) @grad prod(xs) = prod(data(xs)), Δ -> (similar(xs) .= (prod(xs) ./ xs) .* Δ,)
@grad prod(xs, dim) = prod(xs, dim), @grad prod(xs, dim) = prod(data(xs), dim),
Δ -> (similar(xs) .= (reshape(.*(circshift.([reshape(xs, length(xs))], 1:length(xs)-1)...), size(xs))) .* Δ,nothing) Δ -> (similar(xs) .= (reshape(.*(circshift.([reshape(xs, length(xs))], 1:length(xs)-1)...), size(xs))) .* Δ,nothing)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
@ -244,7 +244,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
@grad dot(xs, ys) = dot(xs, ys), Δ -> (Δ .* ys, Δ .* xs) @grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
# Hacks to get std working # Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) = Base.std(x::TrackedArray; mean = Base.mean(x)) =
@ -255,11 +255,11 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
Base.vecnorm(x::TrackedArray, p::Real = 2) = Base.vecnorm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
@grad mean(xs) = mean(xs), Δ -> (similar(xs) .= Δ ./ length(xs),) @grad mean(xs) = mean(data(xs)), Δ -> (similar(xs) .= Δ ./ length(xs),)
@grad mean(xs, region) = mean(xs, region), Δ -> (similar(xs) .= Δ ./ prod(size(xs, region...)),nothing) @grad mean(xs, region) = mean(data(xs), region), Δ -> (similar(xs) .= Δ ./ prod(size(xs, region...)),nothing)
@grad function maximum(xs, r...) @grad function maximum(xs, r...)
maximum(xs, r...), function (Δ) maximum(data(xs), r...), function (Δ)
Δ′ = zeros(xs) Δ′ = zeros(xs)
_, i = findmax(xs, r...) _, i = findmax(xs, r...)
Δ′[i] = Δ Δ′[i] = Δ
@ -267,7 +267,7 @@ Base.vecnorm(x::TrackedArray, p::Real = 2) =
end end
end end
@grad function minimum(xs, r...) @grad function minimum(xs, r...)
minimum(xs, r...), function (Δ) minimum(data(xs), r...), function (Δ)
Δ′ = zeros(xs) Δ′ = zeros(xs)
_, i = findmin(xs, r...) _, i = findmin(xs, r...)
Δ′[i] = Δ Δ′[i] = Δ
@ -278,7 +278,7 @@ end
# BLAS # BLAS
Base.diagm(x::TrackedVector) = track(diagm, x) Base.diagm(x::TrackedVector) = track(diagm, x)
@grad diagm(x) = diagm(x), Δ -> (diag(Δ),) @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
for f in :[*, Ac_mul_B, A_mul_Bc].args for f in :[*, Ac_mul_B, A_mul_Bc].args
@eval begin @eval begin
@ -298,10 +298,10 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args
end end
@grad a::AbstractMatrix * b::AbstractVecOrMat = @grad a::AbstractMatrix * b::AbstractVecOrMat =
a*b, Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ)) data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
@grad Ac_mul_B(a, b) = Ac_mul_B(a, b), Δ -> (A_mul_Bt(Δ, b)', a*Δ) @grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@grad A_mul_Bc(a, b) = A_mul_Bc(a, b), Δ -> (Δ * b, At_mul_B(a, Δ)') @grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
# NNlib # NNlib
@ -310,32 +310,32 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, mea
softmax(xs::TrackedArray) = track(softmax, xs) softmax(xs::TrackedArray) = track(softmax, xs)
@grad softmax(xs) = softmax(xs), Δ -> (∇softmax(Δ, xs),) @grad softmax(xs) = softmax(data(xs)), Δ -> (∇softmax(Δ, xs),)
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
@grad logsoftmax(xs) = logsoftmax(xs), Δ -> (∇logsoftmax(Δ, xs),) @grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (∇logsoftmax(Δ, xs),)
conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...) conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...) conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
@grad conv(x, w; kw...) = @grad conv(x, w; kw...) =
conv(x, w; kw...), conv(data(x), data(w); kw...),
Δ -> (NNlib.∇conv_data(Δ, x, w; kw...), Δ -> (NNlib.∇conv_data(Δ, x, w; kw...),
NNlib.∇conv_filter(Δ, x, w; kw...)) NNlib.∇conv_filter(Δ, x, w; kw...))
maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...) maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
@grad function maxpool(x, k; kw...) @grad function maxpool(x, k; kw...)
y = maxpool(x, k; kw...) y = maxpool(data(x), k; kw...)
y, Δ -> (NNlib.∇maxpool(Δ, y, x, k; kw...), nothing) y, Δ -> (NNlib.∇maxpool(Δ, y, x, k; kw...), nothing)
end end
meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...) meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
@grad function meanpool(x, k; kw...) @grad function meanpool(x, k; kw...)
y = meanpool(x, k; kw...) y = meanpool(data(x), k; kw...)
y, Δ -> (NNlib.∇meanpool(Δ, y, x, k; kw...), nothing) y, Δ -> (NNlib.∇meanpool(Δ, y, x, k; kw...), nothing)
end end

View File

@ -78,6 +78,8 @@ end
Grads() = Grads(ObjectIdDict()) Grads() = Grads(ObjectIdDict())
Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x] Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x) function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x") istracked(x) || error("Object not tracked: $x")
@ -114,15 +116,11 @@ back(::Grads, ::Void, _) = return
function forward(f, ps::Params) function forward(f, ps::Params)
y = f() y = f()
y, function (Δ) y, function (Δ)
g = Grads() g = Grads(ps)
if istracked(y) if istracked(y)
scan(y) scan(y)
back(g, tracker(y), Δ) back(g, tracker(y), Δ)
end end
for p in ps
haskey(g, tracker(p)) ||
(g[tracker(p)] = init_grad(data(p)))
end
return g return g
end end
end end

View File

@ -50,17 +50,17 @@ for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue arity == 1 || continue
@eval begin @eval begin
@grad $M.$f(a::Real) = @grad $M.$f(a::Real) =
$M.$f(a), Δ -> (Δ * $(DiffRules.diffrule(M, f, :(data(a)))),) $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
$M.$f(a::TrackedReal) = track($M.$f, a) $M.$f(a::TrackedReal) = track($M.$f, a)
end end
end end
for (M, f, arity) in DiffRules.diffrules() for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) da, db = DiffRules.diffrule(M, f, :a, :b)
f = :($M.$f) f = :($M.$f)
@eval begin @eval begin
@grad $f(a::Real, b::Real) = $f(a, b), Δ -> (Δ * $da, Δ * $db) @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
$f(a::TrackedReal, b::TrackedReal) = track($f, a, b) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
$f(a::TrackedReal, b::Real) = track($f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b)
$f(a::Real, b::TrackedReal) = track($f, a, b) $f(a::Real, b::TrackedReal) = track($f, a, b)