<!DOCTYPE html>
<html lang="en"><head><meta charset="UTF-8"/><meta name="viewport" content="width=device-width, initial-scale=1.0"/><title>Recurrence · Flux</title><script>(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script><link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/><link href="https://fonts.googleapis.com/css?family=Lato|Roboto+Mono" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/><link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.12.0/styles/default.min.css" rel="stylesheet" type="text/css"/><script>documenterBaseURL=".."</script><script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script><script src="../siteinfo.js"></script><script src="../../versions.js"></script><link href="../assets/documenter.css" rel="stylesheet" type="text/css"/><link href="../../flux.css" rel="stylesheet" type="text/css"/></head><body><nav class="toc"><h1>Flux</h1><select id="version-selector" onChange="window.location.href=this.value" style="visibility: hidden"></select><form class="search" id="search-form" action="../search.html"><input id="search-query" name="q" type="text" placeholder="Search docs"/></form><ul><li><a class="toctext" href="../index.html">Home</a></li><li><span class="toctext">Building Models</span><ul><li><a class="toctext" href="basics.html">Basics</a></li><li class="current"><a class="toctext" href="recurrence.html">Recurrence</a><ul class="internal"><li><a class="toctext" href="#Recurrent-Cells-1">Recurrent Cells</a></li><li><a class="toctext" href="#Stateful-Models-1">Stateful Models</a></li><li><a class="toctext" href="#Sequences-1">Sequences</a></li><li><a class="toctext" href="#Truncating-Gradients-1">Truncating Gradients</a></li></ul></li><li><a class="toctext" href="layers.html">Model Reference</a></li></ul></li><li><span class="toctext">Training Models</span><ul><li><a class="toctext" href="../training/optimisers.html">Optimisers</a></li><li><a class="toctext" 
href="../training/training.html">Training</a></li></ul></li><li><a class="toctext" href="../data/onehot.html">One-Hot Encoding</a></li><li><a class="toctext" href="../gpu.html">GPU Support</a></li><li><a class="toctext" href="../community.html">Community</a></li></ul></nav><article id="docs"><header><nav><ul><li>Building Models</li><li><a href="recurrence.html">Recurrence</a></li></ul></nav><hr/><div id="topbar"><span>Recurrence</span><a class="fa fa-bars" href="#"></a></div></header><h1><a class="nav-anchor" id="Recurrent-Models-1" href="#Recurrent-Models-1">Recurrent Models</a></h1><h2><a class="nav-anchor" id="Recurrent-Cells-1" href="#Recurrent-Cells-1">Recurrent Cells</a></h2><p>In the simple feedforward case, our model <code>f</code> is just a function from various inputs <code>xᵢ</code> to predictions <code>yᵢ</code>. (For example, each <code>x</code> might be an MNIST digit and each <code>y</code> a digit label.) Each prediction is completely independent of any others, and using the same <code>x</code> will always produce the same <code>y</code>.</p><pre><code class="language-julia">y₁ = f(x₁)
y₂ = f(x₂)
y₃ = f(x₃)
# ...</code></pre><p>Recurrent networks introduce a <em>hidden state</em> that gets carried over each time we run the model. The model now takes the old <code>h</code> as an input, and produces a new <code>h</code> as output, each time we run it.</p><pre><code class="language-julia">h = # ... initial state ...
h, y₁ = f(h, x₁)
h, y₂ = f(h, x₂)
h, y₃ = f(h, x₃)
# ...</code></pre><p>Information stored in <code>h</code> is preserved for the next prediction, allowing it to function as a kind of memory. This also means that the prediction made for a given <code>x</code> depends on all the inputs previously fed into the model.</p><p>(This might be important if, for example, each <code>x</code> represents one word of a sentence; the model's interpretation of the word "bank" should change if the previous input was "river" rather than "investment".)</p><p>Flux's RNN support closely follows this mathematical perspective. The most basic RNN is as close as possible to a standard <code>Dense</code> layer, and the output is also the hidden state.</p><pre><code class="language-julia">Wxh = randn(5, 10)
Whh = randn(5, 5)
b = randn(5)

function rnn(h, x)
    h = tanh.(Wxh * x .+ Whh * h .+ b)
    return h, h
end
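# Illustrative sketch, not a Flux function: `runseq` is a hypothetical helper
# that applies a cell such as `rnn` to a whole sequence `xs`, threading the
# hidden state through every call just like the h, y₁ = f(h, x₁) pattern.
function runseq(cell, h, xs)
    ys = []
    for x in xs
        h, y = cell(h, x)  # the state produced here feeds the next step
        push!(ys, y)
    end
    return h, ys  # final hidden state plus the output from every step
end
# e.g. h_end, ys = runseq(rnn, rand(5), [rand(10) for i = 1:3])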

x = rand(10) # dummy data
h = rand(5) # initial hidden state
h, y = rnn(h, x)</code></pre><p>If you run the last line a few times, you'll notice the output <code>y</code> changing slightly even though the input <code>x</code> is the same.</p><p>We sometimes refer to functions like <code>rnn</code> above, which explicitly manage state, as recurrent <em>cells</em>. There are various recurrent cells available, which are documented in the <a href="layers.html">layer reference</a>. The hand-written example above can be replaced with:</p><pre><code class="language-julia">using Flux

rnn2 = Flux.RNNCell(10, 5)

x = rand(10) # dummy data
h = rand(5) # initial hidden state
h, y = rnn2(h, x)</code></pre><h2><a class="nav-anchor" id="Stateful-Models-1" href="#Stateful-Models-1">Stateful Models</a></h2><p>For the most part, we don't want to manage hidden states ourselves, but to treat our models as being stateful. Flux provides the <code>Recur</code> wrapper to do this.</p><pre><code class="language-julia">x = rand(10)
h = rand(5)
m = Flux.Recur(rnn, h)
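# For intuition only: a simplified sketch of the idea behind Recur. MyRecur is
# a hypothetical name, not Flux's actual source. It keeps the cell and its
# current state; each call feeds the stored state in and saves the new one.
mutable struct MyRecur{C}
    cell::C
    state
end

function (r::MyRecur)(x)
    r.state, y = r.cell(r.state, x)  # overwrite the stored state
    return y                         # only the output is returned
end
# Assumption: a fresh MyRecur(rnn, h) should give the same y as a fresh m here.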
y = m(x)</code></pre><p>The <code>Recur</code> wrapper stores the state between runs in the <code>m.state</code> field.</p><p>If you use the <code>RNN(10, 5)</code> constructor – as opposed to <code>RNNCell</code> – you'll see that it's simply a wrapped cell.</p><pre><code class="language-julia">julia> RNN(10, 5)
Recur(RNNCell(Dense(15, 5)))</code></pre><h2><a class="nav-anchor" id="Sequences-1" href="#Sequences-1">Sequences</a></h2><p>Often we want to work with sequences of inputs, rather than individual <code>x</code>s.</p><pre><code class="language-julia">seq = [rand(10) for i = 1:10]</code></pre><p>With <code>Recur</code>, applying our model to each element of a sequence is trivial:</p><pre><code class="language-julia">m.(seq) # returns a list of 5-element vectors</code></pre><p>This works even when we've chained recurrent layers into a larger model.</p><pre><code class="language-julia">m = Chain(LSTM(10, 15), Dense(15, 5))
m.(seq)</code></pre><h2><a class="nav-anchor" id="Truncating-Gradients-1" href="#Truncating-Gradients-1">Truncating Gradients</a></h2><p>By default, calculating the gradients in a recurrent layer involves the entire history. For example, if we call the model on 100 inputs, calling <code>back!</code> will calculate the gradient for those 100 calls. If we then run the model on another 10 inputs, the next gradient calculation has to cover all 110 calls; this cost accumulates and quickly becomes expensive.</p><p>To avoid this, we can <em>truncate</em> the gradient calculation, forgetting the history.</p><pre><code class="language-julia">truncate!(m)</code></pre><p>Calling <code>truncate!</code> wipes the slate clean, so we can call the model with more inputs without building up an expensive gradient computation.</p><p><code>truncate!</code> makes sense when you are working with multiple chunks of a large sequence, where the hidden state should carry over from one chunk to the next. We may also want to work with a set of independent sequences; in that case the hidden state should be completely reset to its original value, throwing away any accumulated information. <code>reset!</code> does this for you.</p><footer><hr/><a class="previous" href="basics.html"><span class="direction">Previous</span><span class="title">Basics</span></a><a class="next" href="layers.html"><span class="direction">Next</span><span class="title">Model Reference</span></a></footer></article></body></html>