325 lines
9.8 KiB
HTML
325 lines
9.8 KiB
HTML
![]() |
<!DOCTYPE html>
|
|||
|
<html lang="en">
|
|||
|
<head>
|
|||
|
<meta charset="UTF-8"/>
|
|||
|
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
|
|||
|
<title>
|
|||
|
Recurrence · Flux
|
|||
|
</title>
|
|||
|
<script>
|
|||
|
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
|
|||
|
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
|
|||
|
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
|
|||
|
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
|
|||
|
|
|||
|
ga('create', 'UA-36890222-9', 'auto');
|
|||
|
ga('send', 'pageview');
|
|||
|
|
|||
|
</script>
|
|||
|
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
|
|||
|
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
|
|||
|
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
|
|||
|
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
|
|||
|
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
|
|||
|
<script>
|
|||
|
documenterBaseURL=".."
|
|||
|
</script>
|
|||
|
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
|
|||
|
<script src="../../versions.js"></script>
|
|||
|
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
|
|||
|
</head>
|
|||
|
<body>
|
|||
|
<nav class="toc">
|
|||
|
<h1>
|
|||
|
Flux
|
|||
|
</h1>
|
|||
|
<form class="search" action="../search.html">
|
|||
|
<select id="version-selector" onChange="window.location.href=this.value">
|
|||
|
<option value="#" selected="selected" disabled="disabled">
|
|||
|
Version
|
|||
|
</option>
|
|||
|
</select>
|
|||
|
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
|
|||
|
</form>
|
|||
|
<ul>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="../index.html">
|
|||
|
Home
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<span class="toctext">
|
|||
|
Building Models
|
|||
|
</span>
|
|||
|
<ul>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="basics.html">
|
|||
|
Model Building Basics
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="templates.html">
|
|||
|
Model Templates
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
<li class="current">
|
|||
|
<a class="toctext" href="recurrent.html">
|
|||
|
Recurrence
|
|||
|
</a>
|
|||
|
<ul class="internal"></ul>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="debugging.html">
|
|||
|
Debugging
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
</ul>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<span class="toctext">
|
|||
|
Other APIs
|
|||
|
</span>
|
|||
|
<ul>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="../apis/batching.html">
|
|||
|
Batching
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="../apis/backends.html">
|
|||
|
Backends
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="../apis/storage.html">
|
|||
|
Storing Models
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
</ul>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<span class="toctext">
|
|||
|
In Action
|
|||
|
</span>
|
|||
|
<ul>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="../examples/logreg.html">
|
|||
|
Simple MNIST
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="../examples/char-rnn.html">
|
|||
|
Char RNN
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
</ul>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="../contributing.html">
|
|||
|
Contributing &amp; Help
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<a class="toctext" href="../internals.html">
|
|||
|
Internals
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
</ul>
|
|||
|
</nav>
|
|||
|
<article id="docs">
|
|||
|
<header>
|
|||
|
<nav>
|
|||
|
<ul>
|
|||
|
<li>
|
|||
|
Building Models
|
|||
|
</li>
|
|||
|
<li>
|
|||
|
<a href="recurrent.html">
|
|||
|
Recurrence
|
|||
|
</a>
|
|||
|
</li>
|
|||
|
</ul>
|
|||
|
<a class="edit-page" href="https://github.com/MikeInnes/Flux.jl/tree/efcb9650da31c183b94b839f66aa3467d007c33f/docs/src/models/recurrent.md">
|
|||
|
<span class="fa">
|
|||
|
|
|||
|
</span>
|
|||
|
Edit on GitHub
|
|||
|
</a>
|
|||
|
</nav>
|
|||
|
<hr/>
|
|||
|
</header>
|
|||
|
<h1>
|
|||
|
<a class="nav-anchor" id="Recurrent-Models-1" href="#Recurrent-Models-1">
|
|||
|
Recurrent Models
|
|||
|
</a>
|
|||
|
</h1>
|
|||
|
<p>
|
|||
|
<a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">
|
|||
|
Recurrence
|
|||
|
</a>
|
|||
|
is a first-class feature in Flux and recurrent models are very easy to build and use. Recurrences are often illustrated as cycles or self-dependencies in the graph; they can also be thought of as a hidden output from / input to the network. For example, for a sequence of inputs
|
|||
|
<code>x1, x2, x3 ...</code>
|
|||
|
we produce predictions as follows:
|
|||
|
</p>
|
|||
|
<pre><code class="language-julia">y1 = f(W, x1) # `f` is the model, `W` represents the parameters
|
|||
|
y2 = f(W, x2)
|
|||
|
y3 = f(W, x3)
|
|||
|
...</code></pre>
|
|||
|
<p>
|
|||
|
Each evaluation is independent and the prediction made for a given input will always be the same. That makes a lot of sense for, say, MNIST images, but less sense when predicting a sequence. For that case we introduce the hidden state:
|
|||
|
</p>
|
|||
|
<pre><code class="language-julia">y1, s = f(W, x1, s)
|
|||
|
y2, s = f(W, x2, s)
|
|||
|
y3, s = f(W, x3, s)
|
|||
|
...</code></pre>
|
|||
|
<p>
|
|||
|
The state
|
|||
|
<code>s</code>
|
|||
|
allows the prediction to depend not only on the current input
|
|||
|
<code>x</code>
|
|||
|
but also on the history of past inputs.
|
|||
|
</p>
|
|||
|
<p>
|
|||
|
The simplest recurrent network looks as follows in Flux, and it should be familiar if you've seen the equations defining an RNN before:
|
|||
|
</p>
|
|||
|
<pre><code class="language-julia">@net type Recurrent
|
|||
|
Wxy; Wyy; by
|
|||
|
y
|
|||
|
function (x)
|
|||
|
y = tanh( x * Wxy + y{-1} * Wyy + by )
|
|||
|
end
|
|||
|
end</code></pre>
|
|||
|
<p>
|
|||
|
The only difference from a regular feed-forward layer is that we create a variable
|
|||
|
<code>y</code>
|
|||
|
which is defined as depending on itself. The
|
|||
|
<code>y{-1}</code>
|
|||
|
syntax means "take the value of
|
|||
|
<code>y</code>
|
|||
|
from the previous run of the network".
|
|||
|
</p>
|
|||
|
<p>
|
|||
|
Using recurrent layers is straightforward and no different from feedforward ones in terms of the
|
|||
|
<code>Chain</code>
|
|||
|
macro etc. For example:
|
|||
|
</p>
|
|||
|
<pre><code class="language-julia">model = Chain(
|
|||
|
Affine(784, 20), σ,
|
|||
|
Recurrent(20, 30),
|
|||
|
Recurrent(30, 15))</code></pre>
|
|||
|
<p>
|
|||
|
Before using the model we need to unroll it. This happens with the
|
|||
|
<code>unroll</code>
|
|||
|
function:
|
|||
|
</p>
|
|||
|
<pre><code class="language-julia">unroll(model, 20)</code></pre>
|
|||
|
<p>
|
|||
|
This call creates an unrolled, feed-forward version of the model which accepts N (= 20) inputs and generates N predictions at a time. Essentially, the model is replicated N times and Flux ties the hidden outputs
|
|||
|
<code>y</code>
|
|||
|
to hidden inputs.
|
|||
|
</p>
|
|||
|
<p>
|
|||
|
Here's a more complex recurrent layer, an LSTM, and again it should be familiar if you've seen the
|
|||
|
<a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">
|
|||
|
equations
|
|||
|
</a>
|
|||
|
:
|
|||
|
</p>
|
|||
|
<pre><code class="language-julia">@net type LSTM
|
|||
|
Wxf; Wyf; bf
|
|||
|
Wxi; Wyi; bi
|
|||
|
Wxo; Wyo; bo
|
|||
|
Wxc; Wyc; bc
|
|||
|
y; state
|
|||
|
function (x)
|
|||
|
# Gates
|
|||
|
forget = σ( x * Wxf + y{-1} * Wyf + bf )
|
|||
|
input = σ( x * Wxi + y{-1} * Wyi + bi )
|
|||
|
output = σ( x * Wxo + y{-1} * Wyo + bo )
|
|||
|
# State update and output
|
|||
|
state′ = tanh( x * Wxc + y{-1} * Wyc + bc )
|
|||
|
state = forget .* state{-1} + input .* state′
|
|||
|
y = output .* tanh(state)
|
|||
|
end
|
|||
|
end</code></pre>
|
|||
|
<p>
|
|||
|
The only unfamiliar part is that we have to define all of the parameters of the LSTM upfront, which adds a few lines at the beginning.
|
|||
|
</p>
|
|||
|
<p>
|
|||
|
Flux's very mathematical notation generalises well to handling more complex models. For example,
|
|||
|
<a href="https://arxiv.org/abs/1409.0473">
|
|||
|
this neural translation model with alignment
|
|||
|
</a>
|
|||
|
can be fairly straightforwardly, and recognisably, translated from the paper into Flux code:
|
|||
|
</p>
|
|||
|
<pre><code class="language-julia"># A recurrent model which takes a token and returns a context-dependent
|
|||
|
# annotation.
|
|||
|
|
|||
|
@net type Encoder
|
|||
|
forward
|
|||
|
backward
|
|||
|
token -> hcat(forward(token), backward(token))
|
|||
|
end
|
|||
|
|
|||
|
Encoder(in::Integer, out::Integer) =
|
|||
|
Encoder(LSTM(in, out÷2), flip(LSTM(in, out÷2)))
|
|||
|
|
|||
|
# A recurrent model which takes a sequence of annotations, attends, and returns
|
|||
|
# a predicted output token.
|
|||
|
|
|||
|
@net type Decoder
|
|||
|
attend
|
|||
|
recur
|
|||
|
state; y; N
|
|||
|
function (anns)
|
|||
|
energies = map(ann -> exp(attend(hcat(state{-1}, ann))[1]), seq(anns, N))
|
|||
|
weights = energies./sum(energies)
|
|||
|
ctx = sum(map((α, ann) -> α .* ann, weights, anns))
|
|||
|
(_, state), y = recur((state{-1},y{-1}), ctx)
|
|||
|
y
|
|||
|
end
|
|||
|
end
|
|||
|
|
|||
|
Decoder(in::Integer, out::Integer; N = 1) =
|
|||
|
Decoder(Affine(in+out, 1),
|
|||
|
unroll1(LSTM(in, out)),
|
|||
|
param(zeros(1, out)), param(zeros(1, out)), N)
|
|||
|
|
|||
|
# The model
|
|||
|
|
|||
|
Nalpha = 5 # The size of the input token vector
|
|||
|
Nphrase = 7 # The length of (padded) phrases
|
|||
|
Nhidden = 12 # The size of the hidden state
|
|||
|
|
|||
|
encode = Encoder(Nalpha, Nhidden)
|
|||
|
decode = Chain(Decoder(Nhidden, Nhidden, N = Nphrase), Affine(Nhidden, Nalpha), softmax)
|
|||
|
|
|||
|
model = Chain(
|
|||
|
unroll(encode, Nphrase, stateful = false),
|
|||
|
unroll(decode, Nphrase, stateful = false, seq = false))</code></pre>
|
|||
|
<p>
|
|||
|
Note that this model exercises some of the more advanced parts of the compiler and isn't stable for general use yet.
|
|||
|
</p>
|
|||
|
<footer>
|
|||
|
<hr/>
|
|||
|
<a class="previous" href="templates.html">
|
|||
|
<span class="direction">
|
|||
|
Previous
|
|||
|
</span>
|
|||
|
<span class="title">
|
|||
|
Model Templates
|
|||
|
</span>
|
|||
|
</a>
|
|||
|
<a class="next" href="debugging.html">
|
|||
|
<span class="direction">
|
|||
|
Next
|
|||
|
</span>
|
|||
|
<span class="title">
|
|||
|
Debugging
|
|||
|
</span>
|
|||
|
</a>
|
|||
|
</footer>
|
|||
|
</article>
|
|||
|
</body>
|
|||
|
</html>
|