Flux.jl/v0.5.3/internals/tracker.html
2018-07-05 11:36:42 +00:00

<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"/><meta name="viewport" content="width=device-width, initial-scale=1.0"/><title>Backpropagation · Flux</title><script>(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script><link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/default.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL=".."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script><script src="../siteinfo.js"></script><script src="../../versions.js"></script><link href="../assets/documenter.css" rel="stylesheet" type="text/css"/><link href="../../flux.css" rel="stylesheet" type="text/css"/></head><body><nav class="toc"><h1>Flux</h1><select id="version-selector" onChange="window.location.href=this.value" style="visibility: hidden"></select><form class="search" id="search-form" action="../search.html"><input id="search-query" name="q" type="text" placeholder="Search docs"/></form><ul><li><a class="toctext" href="../index.html">Home</a></li><li><span class="toctext">Building Models</span><ul><li><a class="toctext" href="../models/basics.html">Basics</a></li><li><a class="toctext" href="../models/recurrence.html">Recurrence</a></li><li><a class="toctext" href="../models/regularisation.html">Regularisation</a></li><li><a class="toctext" href="../models/layers.html">Model Reference</a></li></ul></li><li><span class="toctext">Training Models</span><ul><li><a class="toctext" href="../training/optimisers.html">Optimisers</a></li><li><a class="toctext" href="../training/training.html">Training</a></li></ul></li><li><a class="toctext" href="../data/onehot.html">One-Hot Encoding</a></li><li><a class="toctext" href="../gpu.html">GPU Support</a></li><li><a class="toctext" href="../saving.html">Saving &amp; 
Loading</a></li><li><span class="toctext">Internals</span><ul><li class="current"><a class="toctext" href="tracker.html">Backpropagation</a><ul class="internal"><li><a class="toctext" href="#Internals-1">Internals</a></li><li><a class="toctext" href="#Custom-Gradients-1">Custom Gradients</a></li><li><a class="toctext" href="#Notes-1">Notes</a></li></ul></li></ul></li><li><a class="toctext" href="../community.html">Community</a></li></ul></nav><article id="docs"><header><hr/><div id="topbar"><span>Backpropagation</span><a class="fa fa-bars" href="#"></a></div></header><h1><a class="nav-anchor" id="Flux.Tracker-1" href="#Flux.Tracker-1">Flux.Tracker</a></h1><p>Backpropagation, or reverse-mode automatic differentiation, is handled by the <code>Flux.Tracker</code> module.</p><pre><code class="language-julia">julia&gt; using Flux.Tracker</code></pre><p>The <code>param</code> function converts a normal Julia array into a new object that, while behaving like an array, tracks extra information that allows us to calculate derivatives. For example, say we multiply two parameters:</p><pre><code class="language-julia">julia&gt; W = param([1 2; 3 4])
Tracked 2×2 Array{Float64,2}:
1.0 2.0
3.0 4.0
julia&gt; x = param([5, 6])
Tracked 2-element Array{Float64,1}:
5.0
6.0
julia&gt; y = W*x
Tracked 2-element Array{Float64,1}:
17.0
39.0</code></pre><p>The output <code>y</code> is also a <code>TrackedArray</code> object. We can now backpropagate sensitivities to <code>W</code> and <code>x</code> via the <code>back!</code> function, and see the gradients accumulated in the <code>W</code> and <code>x</code> tracked arrays:</p><pre><code class="language-julia">julia&gt; Tracker.back!(y, [1, -1])
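</code></pre><p>Before inspecting the tracked arrays, it may help to know what to expect. For <code>y = W*x</code>, the standard matrix-calculus result (nothing specific to Flux) is that a sensitivity <code>Δ</code> for <code>y</code> yields a gradient of <code>Δ*x'</code> for <code>W</code> and <code>W'*Δ</code> for <code>x</code>. The check below uses ordinary, untracked arrays; the names <code>W_plain</code>, <code>x_plain</code>, <code>gW</code> and <code>gx</code> are chosen here only so as not to overwrite the tracked <code>W</code> and <code>x</code>:</p><pre><code class="language-julia"># Ordinary arrays only: predict what back!(y, [1, -1]) should accumulate.
W_plain = [1.0 2.0; 3.0 4.0]
x_plain = [5.0, 6.0]
Δ = [1.0, -1.0]            # the sensitivity passed to back!

gW = Δ * x_plain'          # outer product: [5.0 6.0; -5.0 -6.0]
gx = W_plain' * Δ          # [1 3; 2 4] * [1, -1] gives [-2.0, -2.0]</code></pre><p>The gradients accumulated in the tracked arrays match these values:</p><pre><code class="language-julia">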
julia&gt; W.grad
2×2 Array{Float64,2}:
5.0 6.0
-5.0 -6.0
julia&gt; x.grad
2-element Array{Float64,1}:
-2.0
-2.0</code></pre><h2><a class="nav-anchor" id="Internals-1" href="#Internals-1">Internals</a></h2><p>All <code>Tracked*</code> objects (<code>TrackedArray</code>, <code>TrackedReal</code>) are light wrappers around the <code>Tracked</code> type, which you can access via the <code>.tracker</code> field.</p><pre><code class="language-julia">julia&gt; x.tracker
Flux.Tracker.Tracked{Array{Float64,1}}(0x00000000, Flux.Tracker.Call{Void,Tuple{}}(nothing, ()), true, [5.0, 6.0], [-2.0, -2.0])</code></pre><p>The <code>Tracked</code> object stores the value and gradient of a given object, both of which we have already seen.</p><pre><code class="language-julia">julia&gt; x.tracker.data
2-element Array{Float64,1}:
5.0
6.0
julia&gt; x.tracker.grad
2-element Array{Float64,1}:
-2.0
-2.0</code></pre><p>The tracker also contains a <code>Call</code> object, which simply represents a function call that was made at some point during the forward pass. For example, a call to <code>+</code> is represented like this:</p><pre><code class="language-julia">julia&gt; Tracker.Call(+, 1, 2)
Flux.Tracker.Call{Base.#+,Tuple{Int64,Int64}}(+, (1, 2))</code></pre><p>For the <code>y</code> we produced above, we can see that it stores the call that produced it, namely <code>W*x</code>.</p><pre><code class="language-julia">julia&gt; y.tracker.f
Flux.Tracker.Call{...}(*, (param([1.0 2.0; 3.0 4.0]), param([5.0, 6.0])))</code></pre><p>Because the arguments to a call may themselves be tracked arrays that store their own calls, <code>Tracker</code> ends up building a data structure that records everything that happened during the forward pass (often known as a <em>tape</em>).</p><p>When we call <code>back!(y, [1, -1])</code>, the sensitivities <code>[1, -1]</code> are simply forwarded to <code>y</code>&#39;s call (<code>*</code>), effectively calling</p><pre><code class="language-julia">Tracker.back(*, [1, -1], W, x)</code></pre><p>which in turn calculates the sensitivities of the arguments (<code>W</code> and <code>x</code>) and backpropagates them through those arguments&#39; own calls. This is recursive, so it walks the entire program graph and propagates gradients all the way back to the original model parameters.</p><h2><a class="nav-anchor" id="Custom-Gradients-1" href="#Custom-Gradients-1">Custom Gradients</a></h2><p>We can hook into the process above to implement custom gradients for a function or kernel. As a toy example, imagine a custom implementation of <code>minus</code>:</p><pre><code class="language-julia">julia&gt; minus(a, b) = a - b</code></pre><p>First, we must tell the tracker to stop when it sees a call to <code>minus</code> on tracked arrays and to record that call as a single step, rather than tracing through its body. We can do this using dispatch:</p><pre><code class="language-julia">julia&gt; minus(a::TrackedArray, b::TrackedArray) = Tracker.track(minus, a, b)
minus (generic function with 2 methods)</code></pre><p><code>Tracker.track</code> does two things: (1) it makes sure <code>minus</code> is called with <em>normal</em> arrays, not tracked ones (you can use <code>@show</code> inside <code>minus</code> to verify this), and (2) it uses the result to add a <code>minus</code> node to the tape. Look inside the result of calling <code>minus</code> to see what happened:</p><pre><code class="language-julia">julia&gt; a, b = param([6,5,4]), param([1,2,3])
(param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0]))
julia&gt; c = minus(a, b)
Tracked 3-element Array{Float64,1}:
5.0
3.0
1.0
julia&gt; c.tracker.f
Flux.Tracker.Call{...}(minus, (param([6.0, 5.0, 4.0]), param([1.0, 2.0, 3.0])))</code></pre><p>Finally, we have to specify the gradient of <code>minus</code>.</p><pre><code class="language-julia">julia&gt; Tracker.back(::typeof(minus), Δ, a, b) =
(Tracker.@back(a, Δ); Tracker.@back(b, -Δ))</code></pre><p><code>@back(x, Δ)</code> tells the tracker to continue propagating the sensitivity <code>Δ</code> through <code>x</code>. Now, AD will work with any program that calls <code>minus</code>.</p><pre><code class="language-julia">julia&gt; Flux.back!(c, 1)
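</code></pre><p>To make the mechanism more concrete, the record-then-backpropagate pattern can be imitated in a few lines of plain Julia. This is a hypothetical, deliberately simplified sketch, not Flux&#39;s implementation: <code>Node</code>, <code>leaf</code>, <code>unwrap</code>, <code>record</code>, <code>gradrule</code> and <code>backprop!</code> are illustrative names defined here, and only vector-valued calls are supported. It reproduces the <code>minus</code> example, including the recursive walk over recorded calls:</p><pre><code class="language-julia"># Hypothetical toy analogue of Tracker's tape; none of these names are Flux API.
mutable struct Node
    data::Vector{Float64}   # value computed in the forward pass
    grad::Vector{Float64}   # gradient accumulated in the backward pass
    f::Any                  # function that produced this node (nothing for inputs)
    args::Tuple             # arguments that call received
end

# Wrap a plain vector as an input node (plays the role of `param`).
leaf(v) = Node(float.(v), zeros(length(v)), nothing, ())

unwrap(n::Node) = n.data

# Call f on plain data and record the call (the idea behind Tracker.track).
function record(f, args...)
    y = f(map(unwrap, args)...)
    return Node(y, zeros(length(y)), f, args)
end

minus(a, b) = a - b                            # ordinary definition
minus(a::Node, b::Node) = record(minus, a, b)  # intercept tracked inputs

# Gradient rule for minus (the role played by the Tracker.back method above).
gradrule(::typeof(minus), Δ, a, b) = (backprop!(a, Δ); backprop!(b, -Δ))

# Accumulate Δ, then forward it to the call that produced n, recursively.
function backprop!(n::Node, Δ)
    n.grad .+= Δ
    isnothing(n.f) || gradrule(n.f, Δ, n.args...)
    return nothing
end

u, v = leaf([6, 5, 4]), leaf([1, 2, 3])
w = minus(u, v)      # w.data == [5.0, 3.0, 1.0]
backprop!(w, 1.0)    # u.grad == [1.0, 1.0, 1.0], v.grad == [-1.0, -1.0, -1.0]</code></pre><p>As in the description above, recording a call stores the function together with its arguments, and the backward pass dispatches on the recorded function to find its gradient rule. Back in the Flux session, the tracked arrays hold the same gradients:</p><pre><code class="language-julia">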
julia&gt; a.grad
3-element Array{Float64,1}:
1.0
1.0
1.0
julia&gt; b.grad
3-element Array{Float64,1}:
-1.0
-1.0
-1.0</code></pre><h2><a class="nav-anchor" id="Notes-1" href="#Notes-1">Notes</a></h2><p>For multi-argument functions with custom gradients, you will likely want to catch not just <code>minus(::TrackedArray, ::TrackedArray)</code> but also mixed calls such as <code>minus(::AbstractArray, ::TrackedArray)</code>. To do so, define those extra signatures as needed:</p><pre><code class="language-julia"># Only the second argument is tracked
minus(a::AbstractArray, b::TrackedArray) = Tracker.track(minus, a, b)
# Only the first argument is tracked
minus(a::TrackedArray, b::AbstractArray) = Tracker.track(minus, a, b)</code></pre><p><code>@back</code> <em>must</em> be called exactly once on each tracked input argument. No special handling is needed when one of the arguments is not tracked, since <code>@back</code> is then simply a no-op.</p><footer><hr/><a class="previous" href="../saving.html"><span class="direction">Previous</span><span class="title">Saving &amp; Loading</span></a><a class="next" href="../community.html"><span class="direction">Next</span><span class="title">Community</span></a></footer></article></body></html>