Flux.Tracker
Backpropagation, or reverse-mode automatic differentiation, is handled by the Flux.Tracker module.
julia> using Flux.Tracker
Here we discuss some more advanced uses of this module, as well as covering its internals.
Taking Gradients
In the basics section we covered basic usage of the gradient function.
using Flux.Tracker

Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
gradient is actually just a thin wrapper around the backpropagator-based interface, forward.
using Flux.Tracker: forward

y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)

back(1) # (3.0 (tracked), 2.0 (tracked))
The forward function returns two results. The first, y, is the original value of the function (perhaps with tracking applied). The second, back, is a new function which, given a sensitivity, returns the sensitivity of the inputs to forward (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.
julia> y, back = forward((a, b) -> a.*b, [1,2,3], [4,5,6])
(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)

julia> back([1,1,1])
(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
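Seeding back with ones is the same as differentiating the sum of the outputs. As a quick cross-check (our own example, not from the original docs), gradient applied to the summed product reproduces the sensitivities above:

Tracker.gradient((a, b) -> sum(a.*b), [1,2,3], [4,5,6])
# (param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))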
We can also take gradients in-place. This can be useful if you only care about first-order gradients.
a, b = param(2), param(3)

c = a*b # 6.0 (tracked)

Tracker.back!(c)

Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
Tracked Arrays
The param function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0

julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
 5.0
 6.0

julia> y = W*x
Tracked 2-element Array{Float64,1}:
 17.0
 39.0
The output y is also tracked. Calling Tracker.back! with an output sensitivity propagates gradients back to the parameters, which accumulate in their .grad fields:
julia> Tracker.back!(y, [1, -1])

julia> W.grad
2×2 Array{Float64,2}:
  5.0   6.0
 -5.0  -6.0

julia> x.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling Tracker.data(W).
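With the W from above, that looks something like:

julia> Tracker.data(W)
2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0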
Custom Gradients
We can hook in to the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of minus:
minus(a, b) = a - b
Firstly, we must tell the tracker system to stop when it sees a call to minus, and record it. We can do this using dispatch:
using Flux.Tracker: TrackedArray, data, track, @grad

minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
track takes care of building a new Tracked object and recording the operation on the tape. We just need to provide a gradient definition.
@grad function minus(a, b)
  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
This is essentially just a way of overloading the forward function we saw above. We strip tracking from a and b so that we are calling the original definition of minus (otherwise, we'd just try to track the call again and hit an infinite regress).
Note that in the backpropagator we don't call data(a); we do in fact want to track this, since nested AD will take a derivative through the backpropagator itself. For example, the gradient of * might look like this.
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
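To see why this matters, here is a small sketch of nested AD (our own example, not from the original docs). Because gradient returns tracked values, we can differentiate a gradient to get a second derivative:

f(x) = x^3
df(x) = Tracker.gradient(f, x)[1]  # 3x^2, itself tracked

Tracker.gradient(df, 2)  # (12.0 (tracked),) since the second derivative 6x is 12 at x = 2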
For multi-argument functions with custom gradients, you likely want to catch not just minus(::TrackedArray, ::TrackedArray) but also minus(::Array, ::TrackedArray) and so on. To do so, just define those extra signatures as needed:
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
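Putting the pieces together, a quick usage sketch: the tracked result records the minus call on the tape, and back! routes sensitivities through our @grad rule, so a receives Δ and b receives -Δ.

julia> a, b = param([6,5,4]), param([1,2,3]);

julia> c = minus(a, b)
Tracked 3-element Array{Float64,1}:
 5.0
 3.0
 1.0

julia> c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))

julia> Tracker.back!(c, [1, 1, 1])

julia> Tracker.grad(a), Tracker.grad(b)
([1.0, 1.0, 1.0], [-1.0, -1.0, -1.0])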
Tracked Internals
All Tracked* objects (TrackedArray, TrackedReal) are light wrappers around the Tracked type, which you can access via the .tracker field.
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
The Tracker stores the gradient of a given object, which we've seen before.
julia> x.tracker.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
The tracker also contains a Call object, which simply represents a function call that was made at some point during the forward pass. For example, the + call would look like this:
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
In the case of the y we produced above, we can see that it stores the call that produced it – that is, W*x.
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
Notice that because the arguments to the call may also be tracked arrays, storing their own calls, this means that Tracker ends up forming a data structure that records everything that happened during the forward pass (often known as a tape).
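We can walk that tape by hand. Assuming the Call type stores its argument tuple in an args field (its printed form above suggests as much), following y's call to its second argument leads back to x, whose own call is empty because it is a leaf parameter:

julia> y.tracker.f.args
(param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0]))

julia> y.tracker.f.args[2].tracker.f
Flux.Tracker.Call{Void,Tuple{}}(nothing, ())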
When we call back!(y, [1, -1]), the sensitivities [1, -1] simply get forwarded to y's call (*), effectively calling
Tracker.back(*, [1, -1], W, x)
which in turn calculates the sensitivities of the arguments (W and x) and backpropagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.