</script><link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/default.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL=".."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script><script src="../siteinfo.js"></script><script src="../../versions.js"></script><link href="../assets/documenter.css" rel="stylesheet" type="text/css"/><link href="../../flux.css" rel="stylesheet" type="text/css"/></head><body><nav class="toc"><h1>Flux</h1><select id="version-selector" onChange="window.location.href=this.value" style="visibility: hidden"></select><form class="search" id="search-form" action="../search.html"><input id="search-query" name="q" type="text" placeholder="Search docs"/></form><ul><li><a class="toctext" href="../index.html">Home</a></li><li><span class="toctext">Building Models</span><ul><li><a class="toctext" href="../models/basics.html">Basics</a></li><li><a class="toctext" href="../models/recurrence.html">Recurrence</a></li><li><a class="toctext" href="../models/regularisation.html">Regularisation</a></li><li><a class="toctext" href="../models/layers.html">Model Reference</a></li></ul></li><li><span class="toctext">Training Models</span><ul><li><a class="toctext" href="../training/optimisers.html">Optimisers</a></li><li><a class="toctext" href="../training/training.html">Training</a></li></ul></li><li><a class="toctext" href="../data/onehot.html">One-Hot Encoding</a></li><li><a class="toctext" href="../gpu.html">GPU Support</a></li><li><a class="toctext" href="../saving.html">Saving & Loading</a></li><li><span class="toctext">Internals</span><ul><li class="current"><a class="toctext" href="tracker.html">Backpropagation</a><ul class="internal"><li><a class="toctext" href="#Taking-Gradients-1">Taking Gradients</a></li><li><a class="toctext" href="#Tracked-Arrays-1">Tracked Arrays</a></li><li><a class="toctext" href="#Custom-Gradients-1">Custom Gradients</a></li><li><a class="toctext" href="#Tracked-Internals-1">Tracked Internals</a></li></ul></li></ul></li><li><a class="toctext" href="../community.html">Community</a></li></ul></nav><article id="docs"><header><nav><ul><li>Internals</li><li><a href="tracker.html">Backpropagation</a></li></ul><a class="edit-page" href="https://github.com/FluxML/Flux.jl/blob/master/docs/src/internals/tracker.md"><span class="fa"></span> Edit on GitHub</a></nav><hr/><div id="topbar"><span>Backpropagation</span><a class="fa fa-bars" href="#"></a></div></header><h1><a class="nav-anchor" id="Flux.Tracker-1" href="#Flux.Tracker-1">Flux.Tracker</a></h1><p>Backpropagation, or reverse-mode automatic differentiation, is handled by the <code>Flux.Tracker</code> module.</p><pre><code class="language-julia">julia> using Flux.Tracker</code></pre><p>Here we discuss some more advanced uses of this module, as well as covering its internals.</p><h2><a class="nav-anchor" id="Taking-Gradients-1" href="#Taking-Gradients-1">Taking Gradients</a></h2><p>In the <a href="../models/basics.html">basics section</a> we covered basic usage of the <code>gradient</code> function.</p><pre><code 
class="language-julia">using Flux.Tracker
|
||
|
||
Tracker.gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))</code></pre><p><code>gradient</code> is actually just a thin wrapper around the backpropagator-based interface, <code>forward</code>.</p><pre><code class="language-julia">using Flux.Tracker: forward
|
||
|
||
y, back = forward((a, b) -> a*b, 2, 3) # (6.0 (tracked), Flux.Tracker.#9)
|
||
|
||
back(1) # (3.0 (tracked), 2.0 (tracked))</code></pre><p>The <code>forward</code> function returns two results. The first, <code>y</code>, is the original value of the function (perhaps with tracking applied). The second, <code>back</code>, is a new function which, given a sensitivity, returns the sensitivity of the inputs to <code>forward</code> (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.</p><pre><code class="language-julia">julia> y, back = forward((a, b) -> a.*b, [1,2,3],[4,5,6])
|
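To make the "thin wrapper" claim concrete, here is a sketch of how a `gradient`-like function could be built from `forward`. This is not the actual Flux definition, just the idea; `my_gradient` is a hypothetical name:

```julia
# Run the function forward, then feed the seed sensitivity 1 to the
# backpropagator to recover the input gradients.
function my_gradient(f, xs...)
  y, back = forward(f, xs...)
  return back(1)
end

my_gradient((a, b) -> a*b, 2, 3) # (3.0 (tracked), 2.0 (tracked))
```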
The `forward` function returns two results. The first, `y`, is the original value of the function (perhaps with tracking applied). The second, `back`, is a new function which, given a sensitivity, returns the sensitivity of the inputs to `forward` (we call this a "backpropagator"). One use of this interface is to provide custom sensitivities when outputs are not scalar.

```julia
julia> y, back = forward((a, b) -> a.*b, [1, 2, 3], [4, 5, 6])
(param([4.0, 10.0, 18.0]), Flux.Tracker.#9)

julia> back([1, 1, 1])
(param([4.0, 5.0, 6.0]), param([1.0, 2.0, 3.0]))
```

We can also take gradients in-place. This can be useful if you only care about first-order gradients.

```julia
a, b = param(2), param(3)

c = a*b # 6.0 (tracked)

Tracker.back!(c)

Tracker.grad(a), Tracker.grad(b) # (3.0, 2.0)
```
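One caveat worth verifying on your Tracker version: gradients accumulate into `.grad` rather than overwriting it, so a second backward pass adds to the totals from the first. A minimal sketch, reusing `a` and `b` from above (the values in the comment are what accumulation predicts, not captured output):

```julia
# Run the same computation forward and backward a second time; each
# gradient should now hold twice its previous value.
Tracker.back!(a*b)
Tracker.grad(a), Tracker.grad(b) # (6.0, 4.0)
```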
## Tracked Arrays

The `param` function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:

```julia
julia> W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0

julia> x = param([5, 6])
Tracked 2-element Array{Float64,1}:
 5.0
 6.0

julia> y = W*x
Tracked 2-element Array{Float64,1}:
 17.0
 39.0
```

The output `y` is also a `TrackedArray` object. We can now backpropagate sensitivities to `W` and `x` via the `back!` function, and see the gradients accumulated in the `W` and `x` tracked arrays:

```julia
julia> Tracker.back!(y, [1, -1])

julia> W.grad
2×2 Array{Float64,2}:
  5.0   6.0
 -5.0  -6.0

julia> x.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
```

You may sometimes want to drop derivative information and just get the plain value back. You can do this by calling `Tracker.data(W)`.
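For instance, on the `W` defined above, this should hand back the underlying array, along the lines of:

```julia
julia> Tracker.data(W)  # plain Array, no tracking attached
2×2 Array{Float64,2}:
 1.0  2.0
 3.0  4.0
```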
## Custom Gradients

We can hook into the processes above to implement custom gradients for a function or kernel. For a toy example, imagine a custom implementation of `minus`:

```julia
minus(a, b) = a - b
```

Firstly, we must tell the tracker system to stop when it sees a call to `minus`, and record it. We can do this using dispatch:

```julia
using Flux.Tracker: TrackedArray, track, data, @grad

minus(a::TrackedArray, b::TrackedArray) = track(minus, a, b)
```

`track` takes care of building a new `Tracked` object and recording the operation on the tape. We just need to provide a gradient definition.

```julia
@grad function minus(a, b)
  return minus(data(a), data(b)), Δ -> (Δ, -Δ)
end
```

This is essentially just a way of overloading the `forward` function we saw above. We strip tracking from `a` and `b` so that we are calling the original definition of `minus` (otherwise, we'd just try to track the call again and hit an infinite regress).

Note that in the backpropagator we don't call `data(a)`; we *do* in fact want to track this, since nested AD will take a derivative through the backpropagator itself. For example, the gradient of `*` might look like this:

```julia
@grad a * b = data(a)*data(b), Δ -> (Δ*b, a*Δ)
```
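Because the gradients returned through this interface are themselves tracked (as the `(3.0 (tracked), 2.0 (tracked))` output above shows), it should be possible to differentiate through a backpropagator and obtain higher-order derivatives. A sketch under that assumption; exact behaviour varies between Tracker versions:

```julia
# d/dx x^3 = 3x^2, and differentiating once more gives 6x, so the
# second derivative at x = 2 should come out as 12.
f(x) = x^3
df(x) = Tracker.gradient(f, x)[1]  # first derivative, still tracked
Tracker.gradient(df, 2)[1]         # expect 12.0 (tracked)
```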
We can then calculate the first derivative of `minus` as follows:

```julia
a = param([1, 2, 3])
b = param([3, 2, 1])

c = minus(a, b) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]

Tracker.back!(c, 1)
Tracker.grad(a) # [1.00, 1.00, 1.00]
Tracker.grad(b) # [-1.00, -1.00, -1.00]
```

For multi-argument functions with custom gradients, you likely want to catch not just `minus(::TrackedArray, ::TrackedArray)` but also `minus(::Array, ::TrackedArray)` and so on. To do so, just define those extra signatures as needed:

```julia
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)
```
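With those extra methods in place, a mixed call should be recorded just like the fully tracked one. A quick sketch, with the expected result (not captured output) in the comment:

```julia
# A plain array on the left, a tracked array on the right: dispatch
# still routes the call through the tracker.
minus([1, 2, 3], param([3, 2, 1])) # [-2.0 (tracked), 0.0 (tracked), 2.0 (tracked)]
```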
## Tracked Internals

All `Tracked*` objects (`TrackedArray`, `TrackedReal`) are light wrappers around the `Tracked` type, which you can access via the `.tracker` field.

```julia
julia> x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Nothing,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])
```

The `Tracked` object stores the gradient of a given value, which we've seen before.

```julia
julia> x.tracker.grad
2-element Array{Float64,1}:
 -2.0
 -2.0
```

The tracker also contains a `Call` object, which simply represents a function call that was made at some point during the forward pass. For example, the `+` call would look like this:

```julia
julia> Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))
```

In the case of the `y` we produced above, we can see that it stores the call that produced it – that is, `W*x`.

```julia
julia> y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))
```

Notice that because the arguments to the call may also be tracked arrays, storing their own calls, `Tracker` ends up forming a data structure that records everything that happened during the forward pass (often known as a *tape*).

When we call `back!(y, [1, -1])`, the sensitivities `[1, -1]` simply get forwarded to `y`'s call (`*`), effectively calling

```julia
Tracker.back(*, [1, -1], W, x)
```

which in turn calculates the sensitivities of the arguments (`W` and `x`) and back-propagates through their calls. This is recursive, so it will walk the entire program graph and propagate gradients to the original model parameters.
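To see the recursion concretely, here is a small sketch that chains two recorded calls; backpropagating from the final scalar walks `sum`'s call first, then `*`'s, before reaching the leaf parameters (the gradient in the comment is what the chain rule predicts, not captured output):

```julia
W = param([1 2; 3 4])
x = param([5, 6])
z = sum(W*x)      # the tape now records both * and sum

Tracker.back!(z)  # walks back(sum, 1, ...), then back(*, [1, 1], W, x)
W.grad            # expect [5.0 6.0; 5.0 6.0], i.e. ones(2) * x'
```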