Merge pull request #313 from FluxML/ad-overhaul

AD Overhaul
Mike J Innes 2018-07-11 15:33:02 +01:00 committed by GitHub
commit 6d8e6c0440
14 changed files with 568 additions and 377 deletions

View File

@ -1,5 +1,4 @@
julia 0.6.0 julia 0.6.0
DataFlow 0.2.1
Juno Juno
MacroTools 0.3.3 MacroTools 0.3.3
NNlib NNlib

View File

@ -6,6 +6,52 @@ Backpropagation, or reverse-mode automatic differentiation, is handled by the `F
julia> using Flux.Tracker julia> using Flux.Tracker
``` ```
Here we discuss some more advanced uses of this module, as well as covering its internals.
## Taking Gradients
In the [basics section](../models/basics.md) we covered basic usage of the `gradient` function.
```julia
using Flux.Tracker
Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
```
`gradient` is actually just a thin wrapper around the backpropagator-based interface, `forward`.
```julia
using Flux.Tracker: forward
y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)
back(1) # (3.0 (tracked), 2.0 (tracked))
```
The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.
```julia
julia> y, back = forward((a, b) -> a.*b, [1,2,3],[4,5,6])
(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)
julia> back([1,1,1])
(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
```
We can also take gradients in-place. This can be useful if you only care about first-order gradients.
```julia
a, b = param(2), param(3)
c = a*b # 6.0 (tracked)
Tracker.back!(c)
Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
```
## Tracked Arrays
The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters: The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
```julia ```julia
@ -41,7 +87,48 @@ julia> x.grad
-2.0 -2.0
``` ```
## Internals You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`.
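For example, a small sketch (assuming `W` was created with `param` as in the example above):
```julia
W = param([1 2; 3 4])
Tracker.data(W) # a plain 2×2 Array{Float64,2}, with no derivative tracking
```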
## Custom Gradients
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
```julia
minus(a, b) = a - b
```
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
using Flux.Tracker: TrackedArray, track, @grad, data
minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
```
`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.
```julia
@grad function minus(a, b)
return minus(data(a),data(b)), Δ -> (Δ, -Δ)
end
```
This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress).
Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nested AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this.
```julia
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
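With these definitions in place, gradients flow through `minus` like any other function. As a quick sanity check (a sketch; the exact printed form of the results may differ):
```julia
a, b = [2.0, 3.0], [5.0, 7.0]

# the gradient of sum(minus(a, b)) is 1 for each element of `a`
# and -1 for each element of `b`
da, db = Tracker.gradient((a, b) -> sum(minus(a, b)), a, b)
Tracker.data(da) # [1.0, 1.0]
Tracker.data(db) # [-1.0, -1.0]
```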
## Tracked Internals
All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field. All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.
@ -50,14 +137,9 @@ julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0]) Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
``` ```
The `Tracker` stores the value and gradient of a given object, which we've seen before. The `Tracker` stores the gradient of a given object, which we've seen before.
```julia ```julia
julia> x.tracker.data
2-element Array{Float64,1}:
5.0
6.0
julia> x.tracker.grad julia> x.tracker.grad
2-element Array{Float64,1}: 2-element Array{Float64,1}:
-2.0 -2.0
@ -86,71 +168,4 @@ When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forward
Tracker.back(*, [1, -1], W, x) Tracker.back(*, [1, -1], W, x)
``` ```
which in turn calculates the sensitivities of the arguments (`W` and `x`) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters. which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
## Custom Gradients
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:
```julia
julia> minus(a, b) = a - b
```
Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:
```julia
julia> minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus (generic function with 2 methods)
```
`Tracker.track` does two things: (1) it makes sure `minus` is called with *normal* arrays, not tracked ones (you can use `@show` inside `minus` to verify this), and (2) it uses the result to add a `minus` node to the tape. Look inside the result of calling `minus` to see what happened:
```julia
julia> a, b = param([6,5,4]), param([1,2,3])
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))
julia> c = minus(a, b)
Tracked 3-element Array{Float64,1}:
5.0
3.0
1.0
julia> c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))
```
Finally, we have to specify the gradient of `minus`.
```julia
julia> Tracker.back(::typeof(minus), Δ, a, b) =
(Tracker.@back(a, Δ); Tracker.@back(b, -Δ))
```
`@back(x, Δ)` tells the tracker to continue propagating the sensitivity `Δ` through `x`. Now, AD will work with any program that calls `minus`.
```julia
julia> Flux.back!(c, 1)
julia> a.grad
3-element Array{Float64,1}:
1.0
1.0
1.0
julia> b.grad
3-element Array{Float64,1}:
-1.0
-1.0
-1.0
```
## Notes
For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:
```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
`@back` *must* be called exactly once on each tracked input argument. You do not need to do any special handling if one of the arguments is not tracked, as `@back` will just become a no-op.

View File

@ -2,20 +2,74 @@
## Taking Gradients ## Taking Gradients
Consider a simple linear regression, which tries to predict an output array `y` from an input `x`. (It's a good idea to follow this example in the Julia repl.) Flux's core feature is taking gradients of Julia code. The `gradient` function takes another Julia function `f` and a set of arguments, and returns the gradient with respect to each argument. (It's a good idea to try pasting these examples in the Julia terminal.)
```julia
using Flux.Tracker
f(x) = 3x^2 + 2x + 1
# df/dx = 6x + 2
df(x) = Tracker.gradient(f, x)[1]
df(2) # 14.0 (tracked)
# d²f/dx² = 6
d2f(x) = Tracker.gradient(df, x)[1]
d2f(2) # 6.0 (tracked)
```
(We'll learn more about why these numbers show up as `(tracked)` below.)
When a function has many parameters, we can pass them all in explicitly:
```julia
f(W, b, x) = W * x + b
Tracker.gradient(f, 2, 3, 4)
# (4.0 (tracked), 1.0, 2.0 (tracked))
```
But machine learning models can have *hundreds* of parameters! Flux offers a nice way to handle this. We can tell Flux to treat something as a parameter via `param`. Then we can collect these together and tell `gradient` to collect the gradients of all of them at once.
```julia
W = param(2) # 2.0 (tracked)
b = param(3) # 3.0 (tracked)
f(x) = W * x + b
params = Params([W, b])
grads = Tracker.gradient(() -> f(4), params)
grads[W] # 4.0
grads[b] # 1.0
```
There are a few things to notice here. Firstly, `W` and `b` now show up as *tracked*. Tracked things behave like normal numbers or arrays, but keep records of everything you do with them, allowing Flux to calculate their gradients. `gradient` takes a zero-argument function; no arguments are necessary because the `Params` tell it what to differentiate.
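As a small illustration of that (a sketch, reusing `W` from above):
```julia
W * 10          # 20.0 (tracked)
sin(W)          # 0.909... (tracked)
Tracker.data(W) # 2.0, the plain untracked value
```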
This will come in really handy when dealing with big, complicated models. For now, though, let's start with something simple.
## Simple Models
Consider a simple linear regression, which tries to predict an output array `y` from an input `x`.
```julia ```julia
W = rand(2, 5) W = rand(2, 5)
b = rand(2) b = rand(2)
predict(x) = W*x .+ b predict(x) = W*x .+ b
loss(x, y) = sum((predict(x) .- y).^2)
function loss(x, y)
ŷ = predict(x)
sum((y .- ŷ).^2)
end
x, y = rand(5), rand(2) # Dummy data x, y = rand(5), rand(2) # Dummy data
loss(x, y) # ~ 3 loss(x, y) # ~ 3
``` ```
To improve the prediction we can take the gradients of `W` and `b` with respect to the loss function and perform gradient descent. We could calculate gradients by hand, but Flux will do it for us if we tell it that `W` and `b` are trainable *parameters*. To improve the prediction we can take the gradients of `W` and `b` with respect to the loss and perform gradient descent. Let's tell Flux that `W` and `b` are parameters, just like we did above.
```julia ```julia
using Flux.Tracker using Flux.Tracker
@ -23,17 +77,15 @@ using Flux.Tracker
W = param(W) W = param(W)
b = param(b) b = param(b)
l = loss(x, y) gs = Tracker.gradient(() -> loss(x, y), Params([W, b]))
back!(l)
``` ```
`loss(x, y)` returns the same number, but it's now a *tracked* value that records gradients as it goes along. Calling `back!` then accumulates the gradient of `W` and `b`. We can see what this gradient is, and modify `W` to train the model. Now that we have gradients, we can pull them out and update `W` to train the model. The `update!(W, Δ)` function applies `W = W + Δ`, which we can use for gradient descent.
```julia ```julia
using Flux.Tracker: grad, update! using Flux.Tracker: update!
Δ = grad(W) Δ = gs[W]
# Update the parameter and reset the gradient # Update the parameter and reset the gradient
update!(W, -0.1Δ) update!(W, -0.1Δ)
@ -43,7 +95,7 @@ loss(x, y) # ~ 2.5
The loss has decreased a little, meaning that our prediction `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md). The loss has decreased a little, meaning that our prediction `x` is closer to the target `y`. If we have some data we can already try [training the model](../training/training.md).
All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different: they might have millions of parameters or complex control flow, and there are ways to manage this complexity. Let's see what that looks like. All deep learning in Flux, however complex, is a simple generalisation of this example. Of course, models can *look* very different: they might have millions of parameters or complex control flow. Let's see how Flux handles more complex models.
## Building Layers ## Building Layers

View File

@ -103,7 +103,7 @@ m.(seq)
## Truncating Gradients ## Truncating Gradients
By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling `back!` will calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients; this accumulates and quickly becomes expensive. By default, calculating the gradients in a recurrent layer involves its entire history. For example, if we call the model on 100 inputs, we'll have to calculate the gradient for those 100 calls. If we then calculate another 10 inputs we have to calculate 110 gradients; this accumulates and quickly becomes expensive.
To avoid this we can *truncate* the gradient calculation, forgetting the history. To avoid this we can *truncate* the gradient calculation, forgetting the history.
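In Flux this is done by calling `truncate!` on the model. A minimal sketch (assuming `m` is the recurrent model built earlier on this page):
```julia
# Forget the stored history, so the next gradient calculation
# only spans inputs seen from this point on.
Flux.truncate!(m)
```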

View File

@ -3,6 +3,8 @@
Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`. Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
```julia ```julia
using Flux.Tracker
W = param(rand(2, 5)) W = param(rand(2, 5))
b = param(rand(2)) b = param(rand(2))
@ -11,7 +13,9 @@ loss(x, y) = sum((predict(x) .- y).^2)
x, y = rand(5), rand(2) # Dummy data x, y = rand(5), rand(2) # Dummy data
l = loss(x, y) # ~ 3 l = loss(x, y) # ~ 3
back!(l)
params = Params([W, b])
grads = Tracker.gradient(() -> loss(x, y), params)
``` ```
We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that: We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
@ -22,7 +26,7 @@ using Flux.Tracker: grad, update!
function sgd() function sgd()
η = 0.1 # Learning Rate η = 0.1 # Learning Rate
for p in (W, b) for p in (W, b)
update!(p, -η * grad(p)) update!(p, -η * grads[p])
end end
end end
``` ```

View File

@ -286,41 +286,28 @@ function desc(rnn)
return d return d
end end
import Flux.Tracker: data, isleaf, istracked, track, back_, @back, unbroadcast import Flux.Tracker
import Flux.Tracker: data, istracked, track, unbroadcast, @grad, nobacksies
mutable struct RNNCall{R}
rnn::R
reserve::CuVector{UInt8}
RNNCall{R}(rnn::R) where R = new(rnn)
end
RNNCall(rnn) = RNNCall{typeof(rnn)}(rnn)
function (c::RNNCall)(args...)
rs, result = forwardTrain(desc(c.rnn), args...)
c.reserve = rs
return result
end
istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...)) istrain(m::CuRNNs, args...) = any(x -> x isa TrackedArray, (m.Wi, m.Wh, m.b, args...))
function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuRNN{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ? result = istrain(m, h, x) ?
track(RNNCall(m), x, h) : track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h) forward(desc(m), x, h)
return result[2], result[1] return result[2], result[1]
end end
function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuGRU{T})(h::CuParam{T}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ? result = istrain(m, h, x) ?
track(RNNCall(m), x, h) : track(m, x, h, m.Wi, m.Wh, m.b) :
forward(desc(m), x, h) forward(desc(m), x, h)
return result[2], result[1] return result[2], result[1]
end end
function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64} function (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x::CuParam{T}) where T <: Union{Float32,Float64}
result = istrain(m, h, x) ? result = istrain(m, h, x) ?
track(RNNCall(m), x, h[1], h[2]) : track(m, x, h[1], h[2], m.Wi, m.Wh, m.b) :
forward(desc(m), x, h[1], h[2]) forward(desc(m), x, h[1], h[2])
return (result[2], result[3]), result[1] return (result[2], result[3]), result[1]
end end
@ -329,44 +316,29 @@ end
(m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuGRU{T})(h::CuParam{T}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
(m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x)) (m::CuLSTM{T})(h::NTuple{2,CuParam{T}}, x) where T <: Union{Float32,Float64} = m(h, CuArray{T}(x))
function accum_transpose!(dst::CuArray, src::CuArray) @grad function (m::Union{CuRNN,CuGRU})(x, h, Wi, Wh, b)
function kernel(dst, src) reserve, result = forwardTrain(desc(m), data(x), data(h))
I = @cuindex dst result, function (Δ)
dst[I...] += src[reverse(I)...] y, ho = result
return
end
blk, thr = cudims(dst)
@cuda (blk, thr) kernel(dst, src)
return dst
end
function back_(m::RNNCall{<:Union{CuRNN,CuGRU}}, y_, Δ, x, h)
y, ho = y_
dy, dho = Δ dy, dho = Δ
h_ = hBatch(x, data(h)) h_ = hBatch(x, data(h))
dx, dh = backwardData(descs[m.rnn], y, dy, dho, h_, m.reserve) dx, dh = backwardData(descs[m], y, dy, dho, h_, reserve)
@back(x, dx) (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
@back(h, unbroadcast(h, dh)) nobacksies(:RNN, (dx, unbroadcast(size(h), dh), dWi.', dWh.', db))
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) end
# We don't have to make this assumption, it's just slightly more complex.
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b)))
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end end
function back_(m::RNNCall{<:CuLSTM}, y_, Δ, x, h, c) @grad function (m::CuLSTM)(x, h, c, Wi, Wh, b)
y, ho, co = y_ reserve, result = forwardTrain(desc(m), data.((x, h, c))...)
result, function (Δ)
y, ho = result
dy, dho, dco = Δ dy, dho, dco = Δ
h_ = hBatch(x, data(h)) h_ = hBatch(x, data(h))
c_ = hBatch(x, data(c)) c_ = hBatch(x, data(c))
dx, dh, dc = backwardData(descs[m.rnn], y, dy, dho, dco, h_, c_, m.reserve) dx, dh, dc = backwardData(descs[m], y, dy, dho, dco, h_, c_, reserve)
@back(x, dx) (dWi, dWh), db = backwardWeights(descs[m], data(x), h_, y, reserve)
@back(h, unbroadcast(h, dh)) nobacksies(:RNN,
@back(c, unbroadcast(h, dc)) (dx, unbroadcast(size(h), dh), unbroadcast(size(c), dc),
(dWi, dWh), db = backwardWeights(descs[m.rnn], data(x), h_, y, m.reserve) dWi.', dWh.', db))
@assert all(isleaf.((m.rnn.Wi, m.rnn.Wh, m.rnn.b))) end
istracked(m.rnn.Wi) && accum_transpose!(m.rnn.Wi.grad, dWi)
istracked(m.rnn.Wh) && accum_transpose!(m.rnn.Wh.grad, dWh)
istracked(m.rnn.b) && accum_transpose!(m.rnn.b.grad, db)
end end

View File

@ -1,23 +1,27 @@
module Tracker module Tracker
using MacroTools
using MacroTools: @q, @forward
import Base: == import Base: ==
export TrackedArray, TrackedVector, TrackedMatrix, param, back! export TrackedArray, TrackedVector, TrackedMatrix, Params, param, back!
tracker(x) = nothing tracker(x) = nothing
istracked(x) = tracker(x) ≠ nothing istracked(x) = tracker(x) ≠ nothing
isleaf(x) = !istracked(x) || isleaf(tracker(x)) isleaf(x) = !istracked(x) || isleaf(tracker(x))
data(x) = istracked(x) ? data(tracker(x)) : x
grad(x) = grad(tracker(x)) grad(x) = grad(tracker(x))
grad(::Void) = nothing grad(::Void) = nothing
data(x) = x
struct Call{F,As<:Tuple} struct Call{F,As<:Tuple}
func::F func::F
args::As args::As
end end
Call(f, args...) = Call{typeof(f),typeof(args)}(f, args) Call(f, args) = Call{typeof(f),typeof(args)}(f, args)
Call() = Call(nothing, ())
# When deserialising, the object_id changes # When deserialising, the object_id changes
a::Call == b::Call = a.func == b.func && a.args == b.args a::Call == b::Call = a.func == b.func && a.args == b.args
@ -28,31 +32,41 @@ mutable struct Tracked{T}
ref::UInt32 ref::UInt32
f::Call f::Call
isleaf::Bool isleaf::Bool
data::T
grad::T grad::T
Tracked{T}(f::Call, data::T) where T = new(0, f, false, data) Tracked{T}(f::Call) where T = new(0, f, false)
Tracked{T}(f::Call, data::T, grad::T) where T = new(0, f, false, data, grad) Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad)
Tracked{T}(f::Call{Void}, data::T, grad::T) where T = new(0, f, true, data, grad) Tracked{T}(f::Call{Void}, grad::T) where T = new(0, f, true, grad)
end end
Tracked(f::Call, x) = Tracked{typeof(x)}(f, x)
Tracked(f::Call, x, Δ) = Tracked{typeof(x)}(f, x, Δ)
track(f::Call, x) = Tracked(f, x)
track(f::Call) = track(f, f())
track(f, xs...) = track(Call(f, xs...))
istracked(x::Tracked) = true istracked(x::Tracked) = true
isleaf(x::Tracked) = x.f == Call(nothing) isleaf(x::Tracked) = x.f == Call()
data(x::Tracked) = x.data
grad(x::Tracked) = x.grad grad(x::Tracked) = x.grad
track(f::Call, x) = Tracked{typeof(x)}(f)
function _forward end
function track(f, xs...; kw...)
y, back = _forward(f, xs...; kw...)
track(Call(back, tracker.(xs)), y)
end
macro grad(ex)
@capture(shortdef(ex), (name_(args__) = body_) |
(name_(args__) where {T__} = body_)) || error("Need a function definition")
T == nothing && (T = [])
isexpr(name, :(::)) || (name = :(::typeof($name)))
insert!(args, 1+isexpr(args[1], :parameters) , name)
@q(Tracker._forward($(args...)) where $(T...) = $body) |> esc
end
function update!(x, Δ) function update!(x, Δ)
tracker(x).data += Δ x.data .+= data(Δ)
tracker(x).grad .= 0 tracker(x).grad .= 0
return x return x
end end
include("idset.jl")
include("back.jl") include("back.jl")
include("scalar.jl") include("scalar.jl")
include("array.jl") include("array.jl")
@ -66,11 +80,35 @@ Hook into gradient backpropagation. `x` is unmodified, but when backpropagating
the sign of the gradient applied to `x`. the sign of the gradient applied to `x`.
""" """
hook(f, x) = istracked(x) ? track(hook, f, x) : x hook(f, x) = istracked(x) ? track(hook, f, x) : x
back(::typeof(hook), Δ, f, x) = @back(x, f(Δ)) @grad hook(f, x) = x, Δ -> (nothing, f(Δ))
"""
checkpoint(f, args...)
Behaves like `f(args...)`, but avoids storing the intermediate values needed for
calculating gradients. Instead, `f(args...)` will be called again during the
backward pass. This can be used to save memory in larger models.
"""
checkpoint(f, args...) = track(checkpoint, f, args...)
@grad function checkpoint(f, args...)
data(f(args...)), function (Δ)
y, back = forward(f, args...)
(nothing, back(Δ)...)
end
end
nobacksies(f, x) = track(nobacksies, f, x)
nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs)
@grad nobacksies(f, x) = data(x), Δ -> error("Nested AD not defined for $f")
param(x::Number) = TrackedReal(float(x)) param(x::Number) = TrackedReal(float(x))
param(xs::AbstractArray) = TrackedArray(float.(xs)) param(xs::AbstractArray) = TrackedArray(float.(xs))
@grad identity(x) = data(x), Δ -> (Δ,)
param(x::TrackedReal) = track(identity, x)
param(x::TrackedArray) = track(identity, x)
import NNlib.cudata import NNlib.cudata
import Adapt.adapt import Adapt.adapt

View File

@ -6,6 +6,7 @@ struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N}
TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad) TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad)
end end
data(x::TrackedArray) = x.data
tracker(x::TrackedArray) = x.tracker tracker(x::TrackedArray) = x.tracker
TrackedVector{T,A} = TrackedArray{T,1,A} TrackedVector{T,A} = TrackedArray{T,1,A}
@ -15,12 +16,12 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}}
track(c::Call, x::AbstractArray) = TrackedArray(c, x) track(c::Call, x::AbstractArray) = TrackedArray(c, x)
TrackedArray(c::Call, x::A) where A <: AbstractArray = TrackedArray(c::Call, x::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x), x) TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x)
TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray =
TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, x, Δ), x, Δ) TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ)
TrackedArray(x::AbstractArray) = TrackedArray(Call(nothing), x, zeros(x)) TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zeros(x))
Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T}
@ -49,6 +50,9 @@ for f in :[Base.size, Base.ndims].args
@eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...) @eval @inline $f(x::TrackedArray, a...) = $f(data(x), a...)
end end
Base.size(x::TrackedArray, i::Integer, j::Integer, is::Integer...) =
size(data(x), i, j, is...)
Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) = Base.similar(x::TrackedArray, dims::Union{AbstractUnitRange,Integer}...) =
similar(data(x), dims...) similar(data(x), dims...)
@ -62,54 +66,57 @@ Base.:(==)(x::TrackedArray, y::TrackedArray) = data(x) == data(y)
Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...) Base.getindex(xs::TrackedArray, i...) = track(getindex, xs, i...)
function back(::typeof(getindex), Δ, xs::TrackedArray, i...) @grad function getindex(xs::AbstractArray, i...)
Δ′ = zeros(xs.data) data(xs)[i...], function (Δ)
Δ′[i...] = Δ Δ′ = zero(xs)
@back(xs, Δ′) Δ′[i...] = data(Δ)
(nobacksies(:getindex, Δ′), map(_->nothing, i)...)
end
end end
Base.:-(xs::TrackedArray) = track(-, xs) Base.:-(xs::TrackedArray) = track(-, xs)
back(::typeof(-), Δ, xs::TrackedArray) = back(xs, -Δ) @grad -(xs) = -data(xs), Δ -> (-Δ,)
Base.transpose(xs::TrackedArray) = track(transpose, xs) Base.transpose(xs::TrackedArray) = track(transpose, xs)
Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs) Base.ctranspose(xs::TrackedArray) = track(ctranspose, xs)
back(::typeof(transpose), Δ, xs) = @back(xs, trim(xs, Δ.')) @grad transpose(xs) = data(xs).', Δ -> (reshape(Δ.', size(xs)),)
back(::typeof(ctranspose), Δ, xs) = @back(xs, trim(xs, Δ')) @grad ctranspose(xs) = data(xs)', Δ -> (reshape(Δ', size(xs)),)
Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Integer...) = track(repmat, x, a...)
Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...) Base.repmat(x::TrackedVecOrMat, a::Int64...) = track(repmat, x, a...)
function back(::typeof(repmat), Δ, xs::TrackedVecOrMat, m, n=1) @grad function repmat(xs, m, n = 1)
Δ′ = similar(xs.data) repmat(data(xs), m, n), function (Δ)
S = size(xs.data) Δ′ = similar(xs)
for (i,v) in enumerate(Δ) S = size(xs)
for (i,v) in enumerate(data(Δ))
d1 = divrem(i-1, S[1]*m) d1 = divrem(i-1, S[1]*m)
x = d1[2] % S[1]+1 x = d1[2] % S[1]+1
y = d1[1] % S[2]+1 y = d1[1] % S[2]+1
Δ′[x, y] += v Δ′[x, y] += v
end end
back(xs, Δ′) return (nobacksies(:repmat, Δ′), nothing, nothing)
end
end end
Base.repeat(A::TrackedArray; kw...) = track(repeat, A; kw...)
_repeat(A, inner, outer) = Base.repeat(A; inner=inner, outer=outer) @grad function repeat(xs; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A)))
Base.repeat(A::TrackedArray; inner=ntuple(x->1, ndims(A)), outer=ntuple(x->1, ndims(A))) = track(_repeat, A, inner, outer) repeat(data(xs), inner = inner, outer = outer), function (Δ)
Δ′ = zero(xs)
function back(::typeof(_repeat), Δ, xs::TrackedArray, inner, outer) S = size(xs)
Δ′ = similar(xs.data)
Δ′ .= 0
S = size(xs.data)
# Loop through each element of Δ, calculate source dimensions, accumulate into Δ′ # Loop through each element of Δ, calculate source dimensions, accumulate into Δ′
for (dest_idx, val) in enumerate(IndexCartesian(), Δ) for (dest_idx, val) in enumerate(IndexCartesian(), data(Δ))
# First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then # First, round dest_idx[dim] to nearest gridpoint defined by inner[dim], then
# wrap around based on original size S. # wrap around based on original size S.
src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)] src_idx = [mod1(div(dest_idx[dim] - 1, inner[dim]) + 1, S[dim]) for dim in 1:length(S)]
Δ′[src_idx...] += val Δ′[src_idx...] += val
end end
back(xs, Δ′) (nobacksies(:repeat, Δ′),)
end
end end
@ -138,42 +145,51 @@ for f in [:vcat, :hcat]
end end
end end
function back(::typeof(vcat), Δ, xs...) @grad function vcat(xs...)
vcat(data.(xs)...), function (Δ)
start = 0 start = 0
for xsi in xs Δs = [begin
i = map(_ -> :, size(xsi)) |> Base.tail i = map(_ -> :, size(xsi)) |> Base.tail
@back(xsi, Δ[start+1:start+size(xsi,1), i...]) d = Δ[start+1:start+size(xsi,1), i...]
start += size(xsi, 1) start += size(xsi, 1)
d
end for xsi in xs]
return (Δs...,)
end end
end end
function back(::typeof(hcat), Δ, xs...) @grad function hcat(xs...)
hcat(data.(xs)...), function (Δ)
start = 0 start = 0
for xsi in xs Δs = [begin
if ndims(xsi) == 1 d = if ndims(xsi) == 1
@back(xsi, Δ[:, start+1]) Δ[:, start+1]
else else
i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail i = map(_ -> :, size(xsi)) |> Base.tail |> Base.tail
@back(xsi, Δ[:, start+1:start+size(xsi,2), i...]) Δ[:, start+1:start+size(xsi,2), i...]
end end
start += size(xsi, 2) start += size(xsi, 2)
d
end for xsi in xs]
return (Δs...,)
end end
end end
Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...) Base.cat(dims, a::TrackedArray, b::AbstractArray...) = track(cat, dims, a, b...)
Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...) Base.cat(dims, a::Union{RowVector,Array}, b::TrackedArray, c::AbstractArray...) = track(cat, dims, a, b, c...)
function back(::typeof(cat), Δ, dims, Xs...) @grad function cat(dims, Xs...)
cat(dims, data.(Xs)...), function (Δ)
start = ntuple(i -> 0, Val{ndims(Δ)}) start = ntuple(i -> 0, Val{ndims(Δ)})
for xs in Xs Δs = [begin
dim_xs = 1:ndims(xs) dim_xs = 1:ndims(xs)
till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)}) till_xs = ntuple((i -> i in dims ? (i in dim_xs ? size(xs,i) : 1) : 0), Val{ndims(Δ)})
xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)}) xs_in_Δ = ntuple(i -> till_xs[i] > 0 ? (start[i]+1:start[i]+till_xs[i]) : Colon(), Val{ndims(Δ)})
d = reshape(Δ[xs_in_Δ...],size(xs))
@back(xs, reshape(Δ[xs_in_Δ...],size(xs)))
start = start .+ till_xs start = start .+ till_xs
d
end for xs in Xs]
return (nothing, Δs...,)
end end
end end
@ -181,11 +197,10 @@ Base.reshape(xs::TrackedArray, dims::Union{Colon,Int64}...) = reshape(xs, dims)
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims)) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Union{Int64,Colon}}}) = reshape(xs, Base._reshape_uncolon(xs, dims))
Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims) Base.reshape(xs::TrackedArray, dims::Tuple{Vararg{Int64}}) = track(reshape, xs, dims)
back(::typeof(reshape), Δ, xs::TrackedArray, _...) = @grad reshape(xs, dims) = reshape(data(xs), dims), Δ -> (reshape(Δ, size(xs)),nothing)
back(xs, reshape(Δ, size(xs)))
Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims) Base.permutedims(xs::TrackedArray, dims) = track(permutedims, xs, dims)
back(::typeof(permutedims), Δ, xs::TrackedArray, dims) = back(xs, permutedims(Δ, invperm(dims))) @grad permutedims(xs, dims) = permutedims(data(xs), dims), Δ -> (permutedims(Δ, invperm(dims)),nothing)
function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix) function _kron(mat1::AbstractMatrix,mat2::AbstractMatrix)
m1, n1 = size(mat1) m1, n1 = size(mat1)
@ -207,14 +222,18 @@ Base.sum(xs::TrackedArray, dim) = track(sum, xs, dim)
Base.sum(xs::TrackedArray) = track(sum, xs) Base.sum(xs::TrackedArray) = track(sum, xs)
Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs)) Base.sum(f::Union{Function,Type},xs::TrackedArray) = sum(f.(xs))
back(::typeof(sum), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= Δ) @grad sum(xs, dim...) = sum(data(xs), dim...),
Δ -> (zero(xs) .+ Δ, map(_->nothing,dim)...)
Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim) Base.prod(xs::TrackedArray, dim) = track(prod, xs, dim)
Base.prod(xs::TrackedArray) = track(prod, xs) Base.prod(xs::TrackedArray) = track(prod, xs)
Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs)) Base.prod(f::Union{Function, Type}, xs::TrackedArray) = prod(f.(xs))
back(::typeof(prod), Δ, xs::TrackedArray, dim...) = back(xs, similar(xs.data) .= (prod(xs.data, dim...) ./ xs.data) .* Δ) @grad prod(xs) = prod(data(xs)), Δ -> (prod(xs) ./ xs .* Δ,)
back(::typeof(prod), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= (reshape(.*(circshift.([reshape(xs.data, length(xs.data))], 1:length(xs.data)-1)...), size(xs.data))) .* Δ) @grad prod(xs, dim) = prod(data(xs), dim),
Δ -> (nobacksies(:sum,
reshape(.*(circshift.([reshape(data(xs), length(xs))], 1:length(xs)-1)...), size(xs)) .* Δ),
nothing)
Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...) Base.findfirst(xs::TrackedArray, args...) = findfirst(xs.data, args...)
@ -230,10 +249,7 @@ LinAlg.dot(xs::TrackedVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys) LinAlg.dot(xs::AbstractVector, ys::TrackedVector) = track(dot, xs, ys)
LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys) LinAlg.dot(xs::TrackedVector, ys::AbstractVector) = track(dot, xs, ys)
function back(::typeof(dot), Δ, xs, ys) @grad dot(xs, ys) = dot(data(xs), data(ys)), Δ -> (Δ .* ys, Δ .* xs)
@back(xs, Δ.*data(ys))
@back(ys, Δ.*data(xs))
end
# Hacks to get std working # Hacks to get std working
Base.std(x::TrackedArray; mean = Base.mean(x)) = Base.std(x::TrackedArray; mean = Base.mean(x)) =
@ -244,41 +260,32 @@ Base.std(x::TrackedArray, dim; mean = Base.mean(x, dim)) =
Base.vecnorm(x::TrackedArray, p::Real = 2) = Base.vecnorm(x::TrackedArray, p::Real = 2) =
sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0 sum(abs.(x).^p .+ eps(0f0))^(1/p) # avoid d(sqrt(x))/dx == Inf at 0
back(::typeof(mean), Δ, xs::TrackedArray) = back(xs, similar(xs.data) .= Δ ./ length(xs.data)) @grad mean(xs) = mean(data(xs)), Δ -> (Δ / length(xs),)
back(::typeof(mean), Δ, xs::TrackedArray, region) = @grad mean(xs, region) = mean(data(xs), region), Δ -> (zero(xs) .+ Δ ./ prod(size(xs, region...)),nothing)
back(xs, similar(xs.data) .= Δ ./ prod(size(xs.data, region...)))
function back(::typeof(maximum), Δ, xs::TrackedArray) @grad function maximum(xs, r...)
Δ′ = zeros(xs.data) maximum(data(xs), r...), function (Δ)
_, i = findmax(xs.data) Δ′ = zero(xs)
Δ′[i] = Δ _, i = findmax(data(xs), r...)
@back(xs, Δ′) Δ′[i] = data(Δ)
return (nobacksies(:maximum, Δ′),map(_->nothing,r)...)
end
end end
function back(::typeof(maximum), Δ, xs::TrackedArray, region) @grad function minimum(xs, r...)
Δ′ = zeros(xs.data) minimum(data(xs), r...), function (Δ)
_, is = findmax(xs.data, region) Δ′ = zero(xs)
Δ′[is] = Δ _, i = findmin(data(xs), r...)
@back(xs, Δ′) Δ′[i] = data(Δ)
end return (nobacksies(:minimum, Δ′),map(_->nothing,r)...)
function back(::typeof(minimum), Δ, xs::TrackedArray) end
Δ′ = zeros(xs.data)
_, i = findmin(xs.data)
Δ′[i] = Δ
@back(xs, Δ′)
end
function back(::typeof(minimum), Δ, xs::TrackedArray, region)
Δ′ = zeros(xs.data)
_, is = findmin(xs.data, region)
Δ′[is] = Δ
@back(xs, Δ′)
end end
# BLAS # BLAS
Base.diagm(x::TrackedVector) = track(diagm, x) Base.diagm(x::TrackedVector) = track(diagm, x)
back(::typeof(diagm), Δ, x) = @back(x, diag(Δ)) @grad diagm(x) = diagm(data(x)), Δ -> (diag(Δ),)
for f in :[*, Ac_mul_B, A_mul_Bc].args for f in :[*, Ac_mul_B, A_mul_Bc, A_mul_Bt, At_mul_B].args
@eval begin @eval begin
import Base.$f import Base.$f
$f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b) $f(a::TrackedMatrix, b::TrackedMatrix) = track($f, a, b)
@ -295,30 +302,14 @@ for f in :[*, Ac_mul_B, A_mul_Bc].args
end end
end end
function back(::typeof(*), Δ, a::AbstractMatrix, b::AbstractVecOrMat) @grad a::AbstractMatrix * b::AbstractVecOrMat =
@back(a, A_mul_Bt(Δ, data(b))) data(a)*data(b), Δ -> (A_mul_Bt(Δ, b), At_mul_B(a, Δ))
@back(b, At_mul_B(data(a), Δ))
end
function back(::typeof(Ac_mul_B), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) @grad Ac_mul_B(a, b) = Ac_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@back(a, A_mul_Bt(Δ, data(b))') @grad A_mul_Bc(a, b) = A_mul_Bc(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
@back(b, data(a)*Δ)
end
function back(::typeof(A_mul_Bc), Δ, a::AbstractVecOrMat{<:Real}, b::AbstractVecOrMat{<:Real}) @grad At_mul_B(a, b) = At_mul_B(data(a), data(b)), Δ -> (A_mul_Bt(Δ, b)', a*Δ)
@back(a, Δ * data(b)) @grad A_mul_Bt(a, b) = A_mul_Bt(data(a), data(b)), Δ -> (Δ * b, At_mul_B(a, Δ)')
@back(b, At_mul_B(data(a), Δ)')
end
# Fast path for matrix-vector
function back(::typeof(*), Δ::AbstractVector, W::TrackedMatrix, x::AbstractVector)
if isleaf(W)
W.grad .+= Δ .* data(x).'
else
back(W, A_mul_Bt(Δ, data(x)))
end
@back(x, At_mul_B(data(W), Δ))
end
# NNlib # NNlib
@ -327,82 +318,69 @@ import NNlib: softmax, ∇softmax, logsoftmax, ∇logsoftmax, conv, maxpool, mea
softmax(xs::TrackedArray) = track(softmax, xs) softmax(xs::TrackedArray) = track(softmax, xs)
back(::typeof(softmax), Δ, xs) = @back(xs, ∇softmax(Δ, data(xs))) @grad softmax(xs) = softmax(data(xs)), Δ -> (nobacksies(:softmax, ∇softmax(data(Δ), data(xs))),)
logsoftmax(xs::TrackedArray) = track(logsoftmax, xs) logsoftmax(xs::TrackedArray) = track(logsoftmax, xs)
back(::typeof(logsoftmax), Δ, xs) = @back(xs, ∇logsoftmax(Δ, data(xs))) @grad logsoftmax(xs) = logsoftmax(data(xs)), Δ -> (nobacksies(:logsoftmax, ∇logsoftmax(data(Δ), data(xs))),)
# TODO: can store kwargs efficiently in namedtuples conv(x::TrackedArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
_conv(x, w, stride, pad, dilation) = conv(x, w, stride = stride, pad = pad, dilation = dilation) conv(x::AbstractArray, w::TrackedArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray, w::AbstractArray; kw...) = track(conv, x, w; kw...)
conv(x::TrackedArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = @grad conv(x, w; kw...) =
track(_conv, x, w, stride, pad, dilation) conv(data(x), data(w); kw...),
conv(x::AbstractArray{<:Real,N}, w::TrackedArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = Δ -> nobacksies(:conv,
track(_conv, x, w, stride, pad, dilation) (NNlib.∇conv_data(data.((Δ, x, w))...; kw...),
conv(x::TrackedArray{<:Real,N}, w::AbstractArray{<:Real,N}; stride = 1, pad = 0, dilation = 1) where N = NNlib.∇conv_filter(data.((Δ, x, w))...; kw...)))
track(_conv, x, w, stride, pad, dilation)
function back(::typeof(_conv), Δ, x, w, stride, pad, dilation) maxpool(x::TrackedArray, k; kw...) = track(maxpool, x, k; kw...)
@back(x, NNlib.∇conv_data(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation))
@back(w, NNlib.∇conv_filter(Δ, data(x), data(w); stride = stride, pad = pad, dilation = dilation)) @grad function maxpool(x, k; kw...)
y = maxpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇maxpool(data.((Δ, y, x))..., k; kw...)), nothing)
end end
_maxpool(x, k, pad, stride) = maxpool(x, k; pad = pad, stride = stride) meanpool(x::TrackedArray, k; kw...) = track(meanpool, x, k; kw...)
maxpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) = @grad function meanpool(x, k; kw...)
track(_maxpool, x, k, pad, stride) y = meanpool(data(x), k; kw...)
y, Δ -> (nobacksies(:maxpool, NNlib.∇meanpool(data.((Δ, y, x))..., k; kw...)), nothing)
back_(::typeof(_maxpool), y, Δ, x, k, pad, stride) = end
back(x, NNlib.∇maxpool(Δ, y, data(x), k, pad=pad, stride=stride))
_meanpool(x, k, pad, stride) = meanpool(x, k; pad = pad, stride = stride)
meanpool(x::TrackedArray, k; pad = map(_->0,k), stride = k) =
track(_meanpool, x, k, pad, stride)
back_(::typeof(_meanpool), y, Δ, x, k, pad, stride) =
back(x, NNlib.∇meanpool(Δ, y, data(x), k, pad=pad, stride=stride))
# Broadcasting # Broadcasting
using ForwardDiff: Dual, partials using ForwardDiff: Dual, partials, value
struct Broadcasted{F,T}
f::F
data::T
end
(b::Broadcasted)(xs...) = map(x -> x.value, b.data)
dualify(xs, n) = xs dualify(xs, n) = xs
dualify(xs::TrackedArray, ps) = map(x -> Dual(x, ps), data(xs)) dualify(xs::AbstractArray, ps) = map(x -> Dual(x, ps), xs)
dualify(xs::TrackedReal, ps) = Dual(data(xs), ps) dualify(xs::Real, ps) = Dual(xs, ps)
function tracked_broadcast(f, args::Vararg{Any,N}) where N unbroadcast(x::Tuple, Δ) =
dargs = map((x,i) -> dualify(x, ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N})) x == size(Δ) ? Δ :
out = broadcast(f, dargs...) reshape(sum(Δ, filter(n -> n > length(x) || x[n] == 1, 1:ndims(Δ))), x)
eltype(out) <: Dual || return out
b = Broadcasted(f, out)
track(Call(b, args...), b())
end
trim(x, Δ) = reshape(Δ, ntuple(i -> size(Δ, i), Val{ndims(x)})) unbroadcast(x::Tuple{}, Δ) = sum(Δ)
unbroadcast(x::AbstractArray, Δ) =
size(x) == size(Δ) ? Δ :
trim(x, sum(Δ, filter(n -> size(x, n) == 1, 1:ndims(Δ))))
unbroadcast(x::Number, Δ) = sum(Δ)
function getpartial(Δ, x, i) function getpartial(Δ, x, i)
@inbounds p = getindex(partials(x), i) @inbounds p = getindex(partials(x), i)
return Δ * p return Δ * p
end end
function back(b::Broadcasted, Δ, args::Vararg{Any,N}) where N function ∇broadcast(f, args::Vararg{Any,N}) where N
Δargs = ntuple(i -> getpartial.(Δ, b.data, i), Val{N}) sizes = size.(args)
foreach((x, Δ) -> @back(x, unbroadcast(x, Δ)), args, Δargs) dargs = map((x,i) -> dualify(data(x), ntuple(j -> i==j, Val{N})), args, ntuple(identity, Val{N}))
out = broadcast(f, dargs...)
eltype(out) <: Dual || return out
y = value.(out)
back = function (Δ_)
Δ = data(Δ_)
Δargs = ntuple(i -> getpartial.(Δ, out, i), Val{N})
dxs = map((x, Δ) -> unbroadcast(x, Δ), sizes, Δargs)
nobacksies(:broadcast, dxs)
end
# So we can return non-tracked arrays
track(Call(back, tracker.(args)), y)
end end
Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray Base.Broadcast._containertype(::Type{<:TrackedReal}) = TrackedArray
@ -415,4 +393,4 @@ Base.Broadcast.promote_containertype(ct, ::Type{TrackedArray}) = TrackedArray
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = () Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A::Ref) = ()
Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A) Base.Broadcast.broadcast_indices(::Type{TrackedArray}, A) = indices(A)
Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = tracked_broadcast(f, A, Bs...) Base.Broadcast.broadcast_c(f, ::Type{TrackedArray}, A, Bs...) = broadcast(f, A, Bs...)

View File

@ -10,8 +10,6 @@ function scan(x::Tracked)
if ref == 1 if ref == 1
scan(x.f) scan(x.f)
isdefined(x, :grad) && (x.grad = zero_grad!(x.grad)) isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
else
isdefined(x, :grad) || (x.grad = init_grad(x.data))
end end
return return
end end
@ -21,9 +19,14 @@ function scan(x)
return return
end end
back_(f, y, args...) = back(f, args...) function back_(c::Call, Δ)
back_(c::Call, y, Δ) = back_(c.func, y, Δ, c.args...) Δs = c.func(Δ)
back_(::Call{Void}, y, Δ) = nothing (Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach(back, c.args, data.(Δs))
end
back_(::Call{Void}, Δ) = nothing
accum!(x, Δ) = x .+ Δ accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ) accum!(x::AbstractArray, Δ) = (x .+= Δ)
@ -31,33 +34,121 @@ accum!(x::AbstractArray, Δ) = (x .+= Δ)
function back(x::Tracked, Δ) function back(x::Tracked, Δ)
x.isleaf && (x.grad = accum!(x.grad, Δ); return) x.isleaf && (x.grad = accum!(x.grad, Δ); return)
ref = x.ref -= 1 ref = x.ref -= 1
if ref > 0 || isdefined(x, :grad)
if isdefined(x, :grad) if isdefined(x, :grad)
x.grad = accum!(x.grad, Δ) x.grad = accum!(x.grad, Δ)
ref == 0 && back_(x.f, x.data, x.grad)
else else
ref == 0 && back_(x.f, x.data, Δ) x.grad = Δ
end
ref == 0 && back_(x.f, x.grad)
else
ref == 0 && back_(x.f, Δ)
end end
return return
end end
back(x, Δ) = back(tracker(x), Δ) back(::Void, _) = return
back(x::Void, Δ) = error("Can't backpropagate through `nothing`")
macro back(x, Δ)
quote
x = $(esc(x))
istracked(x) && back(x, $(esc(Δ)))
end
end
# Interface methods # Interface methods
# TODO: if an error occurs in `back` the refcounts will be broken # TODO: if an error occurs in `back` the refcounts will be broken
# and `back` will silently fail to update. # and `back` will silently fail to update.
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)
function back!(x::Tracked, Δ) function back!(x, Δ)
istracked(x) || return
scan(x) scan(x)
back(x, Δ) back(tracker(x), Δ)
return
end end
back!(x, Δ) = back!(tracker(x), Δ) # Out-of-place gradients
struct Params
params::IdSet
Params(xs) = new(IdSet(xs))
end
@forward Params.params Base.start, Base.next, Base.done
function Base.show(io::IO, ps::Params)
print(io, "Params([")
join(io, ps.params, ", ")
print(io, "])")
end
struct Grads
grads::ObjectIdDict
end
Base.show(io::IO, ps::Grads) = println(io, "Grads(...)")
Grads() = Grads(ObjectIdDict())
Grads(ps::Params) = Grads(ObjectIdDict(tracker(p) => init_grad(data(p)) for p in ps))
Base.getindex(g::Grads, x::Tracked) = g.grads[x]
function Base.getindex(g::Grads, x)
istracked(x) || error("Object not tracked: $x")
g[tracker(x)]
end
@forward Grads.grads Base.setindex!, Base.haskey
accum!(g::Grads, x, Δ) = g[x] = haskey(g, x) ? g[x] + Δ : Δ
function back_(g::Grads, c::Call, Δ)
Δs = c.func(Δ)
(Δs isa Tuple && length(Δs) >= length(c.args)) ||
error("Gradient is not a tuple of length $(length(c.args))")
foreach((x, Δ) -> back(g, x, Δ), c.args, Δs)
end
back_(g::Grads, ::Call{Void}, Δ) = nothing
function back(g::Grads, x::Tracked, Δ)
x.isleaf && (accum!(g, x, Δ); return)
ref = x.ref -= 1
if ref > 0 || haskey(g, x)
accum!(g, x, Δ)
ref == 0 && back_(g, x.f, g[x])
else
ref == 0 && back_(g, x.f, Δ)
end
return
end
back(::Grads, ::Void, _) = return
function forward(f, ps::Params)
y = f()
y, function (Δ)
g = Grads(ps)
if istracked(y)
scan(y)
back(g, tracker(y), Δ)
end
return g
end
end
function forward(f, args...)
args = param.(args)
y, back = forward(() -> f(args...), Params(args))
y, Δ -> getindex.(back(Δ), args)
end
function losscheck(x)
x isa Real || error("Function output is not scalar")
isinf(x) && error("Loss is infinite")
isnan(x) && error("Loss is NaN")
end
function gradient(f, args...)
y, back = forward(f, args...)
losscheck(y)
return back(1)
end
derivative(f, x) = gradient(f, x)[1]

src/tracker/idset.jl (new file, 25 additions)
View File

@ -0,0 +1,25 @@
struct IdSet{T} <: AbstractSet{T}
dict::ObjectIdDict
IdSet{T}() where T = new(ObjectIdDict())
end
Base.eltype{T}(::IdSet{T}) = T
IdSet() = IdSet{Any}()
Base.push!{T}(s::IdSet{T}, x::T) = (s.dict[x] = nothing; s)
Base.delete!{T}(s::IdSet{T}, x::T) = (delete!(s.dict, x); s)
Base.in(x, s::IdSet) = haskey(s.dict, x)
(::Type{IdSet{T}}){T}(xs) = push!(IdSet{T}(), xs...)
IdSet(xs) = IdSet{eltype(xs)}(xs)
Base.collect(s::IdSet) = Base.collect(keys(s.dict))
Base.similar(s::IdSet, T::Type) = IdSet{T}()
@forward IdSet.dict Base.length
Base.start(s::IdSet) = start(keys(s.dict))
Base.next(s::IdSet, st) = next(keys(s.dict), st)
Base.done(s::IdSet, st) = done(keys(s.dict), st)

View File

@ -1,9 +1,3 @@
function gradient(f, xs...)
xs = param.(xs)
back!(f(xs...))
grad.(xs)
end
function ngradient(f, xs::AbstractArray...) function ngradient(f, xs::AbstractArray...)
grads = zeros.(xs) grads = zeros.(xs)
for (x, Δ) in zip(xs, grads), i in 1:length(x) for (x, Δ) in zip(xs, grads), i in 1:length(x)
@ -21,4 +15,4 @@ end
gradcheck(f, xs...) = gradcheck(f, xs...) =
all(isapprox.(ngradient(f, xs...), all(isapprox.(ngradient(f, xs...),
gradient(f, xs...), rtol = 1e-5, atol = 1e-5)) data.(gradient(f, xs...)), rtol = 1e-5, atol = 1e-5))

View File

@ -1,12 +1,14 @@
struct TrackedReal{T<:Real} <: Real struct TrackedReal{T<:Real} <: Real
data::T
tracker::Tracked{T} tracker::Tracked{T}
end end
TrackedReal(x::Real) = TrackedReal(Tracked(Call(nothing), x, zero(x))) TrackedReal(x::Real) = TrackedReal(x, Tracked{typeof(x)}(Call(), zero(x)))
data(x::TrackedReal) = x.data
tracker(x::TrackedReal) = x.tracker tracker(x::TrackedReal) = x.tracker
track(f::Call, x::Real) = TrackedReal(Tracked(f, x, zero(x))) track(f::Call, x::Real) = TrackedReal(x, Tracked{typeof(x)}(f, zero(x)))
function back!(x::TrackedReal) function back!(x::TrackedReal)
isinf(x) && error("Loss is Inf") isinf(x) && error("Loss is Inf")
@ -47,23 +49,21 @@ using DiffRules, SpecialFunctions, NaNMath
for (M, f, arity) in DiffRules.diffrules() for (M, f, arity) in DiffRules.diffrules()
arity == 1 || continue arity == 1 || continue
@eval begin @eval begin
@grad $M.$f(a::Real) =
$M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
$M.$f(a::TrackedReal) = track($M.$f, a) $M.$f(a::TrackedReal) = track($M.$f, a)
back(::typeof($M.$f), Δ::Real, a::TrackedReal) =
back(a, Δ * $(DiffRules.diffrule(M, f, :(data(a)))))
end end
end end
for (M, f, arity) in DiffRules.diffrules() for (M, f, arity) in DiffRules.diffrules()
arity == 2 || continue arity == 2 || continue
da, db = DiffRules.diffrule(M, f, :(data(a)), :(data(b))) da, db = DiffRules.diffrule(M, f, :a, :b)
f = :($M.$f)
@eval begin @eval begin
$M.$f(a::TrackedReal, b::TrackedReal) = track($M.$f, a, b) @grad $f(a::Real, b::Real) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
$M.$f(a::TrackedReal, b::Real) = track($M.$f, a, b) $f(a::TrackedReal, b::TrackedReal) = track($f, a, b)
$M.$f(a::Real, b::TrackedReal) = track($M.$f, a, b) $f(a::TrackedReal, b::Real) = track($f, a, b)
function back(::typeof($M.$f), Δ::Real, a::Real, b::Real) $f(a::Real, b::TrackedReal) = track($f, a, b)
@back(a, Δ * $da)
@back(b, Δ * $db)
end
end end
end end
@ -75,16 +75,18 @@ import Base:^
# Tuples # Tuples
struct TrackedTuple{T<:Tuple} struct TrackedTuple{T<:Tuple}
data::T
tracker::Tracked{T} tracker::Tracked{T}
end end
data(xs::TrackedTuple) = xs.data
tracker(xs::TrackedTuple) = xs.tracker tracker(xs::TrackedTuple) = xs.tracker
accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ) accum!(x::Tuple, Δ::Tuple) = accum!.(x, Δ)
init_grad(x::Tuple) = init_grad.(x) init_grad(x::Tuple) = init_grad.(x)
zero_grad!(x::Tuple) = zero_grad!.(x) zero_grad!(x::Tuple) = zero_grad!.(x)
track(f::Call, xs::Tuple) = TrackedTuple(Tracked(f, xs)) track(f::Call, xs::Tuple) = TrackedTuple(xs, Tracked{typeof(xs)}(f, zero.(xs)))
function Base.show(io::IO, xs::TrackedTuple) function Base.show(io::IO, xs::TrackedTuple)
show(io, data(xs)) show(io, data(xs))
@ -95,20 +97,21 @@ Base.length(x::TrackedTuple) = length(data(x))
Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i) Base.getindex(xs::TrackedTuple, i::Integer) = track(getindex, xs, i)
back(::typeof(getindex), Δ, t, i) = @grad function getindex(xs::TrackedTuple, i)
back(t, ntuple(j -> i == j ? Δ : 0, length(t))) data(xs)[i], Δ -> (ntuple(j -> i == j ? Δ : 0, length(xs)), nothing)
end
# Array collection # Array collection
function collect(xs) function collect(xs)
xs = Base.collect(xs) xs = Base.collect(xs)
track(Call(collect, xs), data.(xs)) track(Call(collect, (tracker.(xs),)), data.(xs))
end end
function scan(c::Call{typeof(collect)}) function scan(c::Call{typeof(collect)})
foreach(scan, c.args[1]) foreach(scan, c.args[1])
end end
function back(::typeof(collect), Δ, xs) function back_(c::Call{typeof(collect)}, Δ)
foreach((x, Δ) -> @back(x, Δ), xs, Δ) foreach(back, c.args[1], data(Δ))
end end

View File

@ -1,4 +1,5 @@
import Adapt: adapt import Adapt: adapt
import .Tracker: IdSet
children(x) = () children(x) = ()
mapchildren(f, x) = x mapchildren(f, x) = x
@ -20,9 +21,7 @@ function mapleaves(f, x; cache = ObjectIdDict())
cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x) cache[x] = isleaf(x) ? f(x) : mapchildren(x -> mapleaves(f, x, cache = cache), x)
end end
using DataFlow: OSet function prefor(f, x; seen = IdSet())
function prefor(f, x; seen = OSet())
x ∈ seen && return x ∈ seen && return
f(x) f(x)
foreach(x -> prefor(f, x, seen = seen), children(x)) foreach(x -> prefor(f, x, seen = seen), children(x))

View File

@ -1,5 +1,5 @@
using Flux.Tracker, Base.Test, NNlib using Flux.Tracker, Base.Test, NNlib
using Flux.Tracker: TrackedReal, gradcheck, grad using Flux.Tracker: TrackedReal, gradcheck, grad, derivative, checkpoint
using NNlib: conv using NNlib: conv
gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...)
@ -111,6 +111,7 @@ end
@test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6)) @test gradtest(x -> permutedims(x, [3,1,2]), rand(4,5,6))
# TODO unreliable
@test gradtest(x -> repmat(x, 5,5), rand(4,5)) @test gradtest(x -> repmat(x, 5,5), rand(4,5))
@test gradtest(x -> repmat(x, 5), rand(4,5)) @test gradtest(x -> repmat(x, 5), rand(4,5))
@ -232,4 +233,24 @@ Tracker.back!(b)
@test grad.((x,y)) == (3, 2) @test grad.((x,y)) == (3, 2)
end end
# Gradient Hooks
@testset "Hooks" begin
x = param(2)
y = Tracker.hook(-, x)
back!(y)
@test grad(x) == -1
end
@testset "Checkpointing" begin
count = 0
function mul(a, b)
count += 1
a * b
end
@test derivative(x -> mul(5, x), 3) == 5
@test count == 1
@test derivative(x -> checkpoint(mul, 5, x), 3) == 5
@test count == 3
end
end #testset end #testset