<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>Recurrence · Flux</title>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
<script>documenterBaseURL=".."</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
<script src="../../versions.js"></script>
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<nav class="toc">
<h1>Flux</h1>
<form class="search" action="../search.html">
<select id="version-selector" onChange="window.location.href=this.value">
<option value="#" selected="selected" disabled="disabled">Version</option>
</select>
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
</form>
<ul>
<li><a class="toctext" href="../index.html">Home</a></li>
<li>
<span class="toctext">Building Models</span>
<ul>
<li><a class="toctext" href="basics.html">Model Building Basics</a></li>
<li><a class="toctext" href="templates.html">Model Templates</a></li>
<li class="current">
<a class="toctext" href="recurrent.html">Recurrence</a>
<ul class="internal"></ul>
</li>
<li><a class="toctext" href="debugging.html">Debugging</a></li>
</ul>
</li>
<li>
<span class="toctext">Other APIs</span>
<ul>
<li><a class="toctext" href="../apis/batching.html">Batching</a></li>
<li><a class="toctext" href="../apis/backends.html">Backends</a></li>
<li><a class="toctext" href="../apis/storage.html">Storing Models</a></li>
</ul>
</li>
<li>
<span class="toctext">In Action</span>
<ul>
<li><a class="toctext" href="../examples/logreg.html">Simple MNIST</a></li>
<li><a class="toctext" href="../examples/char-rnn.html">Char RNN</a></li>
</ul>
</li>
<li><a class="toctext" href="../contributing.html">Contributing &amp; Help</a></li>
<li><a class="toctext" href="../internals.html">Internals</a></li>
</ul>
</nav>
<article id="docs">
<header>
<nav>
<ul>
<li>Building Models</li>
<li><a href="recurrent.html">Recurrence</a></li>
</ul>
<a class="edit-page" href="https://github.com/MikeInnes/Flux.jl/tree/efcb9650da31c183b94b839f66aa3467d007c33f/docs/src/models/recurrent.md">
<span class="fa"></span>
Edit on GitHub
</a>
</nav>
<hr/>
</header>
<h1><a class="nav-anchor" id="Recurrent-Models-1" href="#Recurrent-Models-1">Recurrent Models</a></h1>
<p>
<a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">Recurrence</a>
is a first-class feature in Flux, and recurrent models are easy to build and use. Recurrences are often illustrated as cycles or self-dependencies in the graph; they can also be thought of as a hidden output from / input to the network. For example, for a sequence of inputs
<code>x1, x2, x3 ...</code>
we produce predictions as follows:
</p>
<pre><code class="language-julia">y1 = f(W, x1) # `f` is the model, `W` represents the parameters
y2 = f(W, x2)
y3 = f(W, x3)
...</code></pre>
<p>
Each evaluation is independent and the prediction made for a given input will always be the same. That makes a lot of sense for, say, MNIST images, but less sense when predicting a sequence. For that case we introduce the hidden state:
</p>
<pre><code class="language-julia">y1, s = f(W, x1, s)
y2, s = f(W, x2, s)
y3, s = f(W, x3, s)
...</code></pre>
<p>
The state <code>s</code> allows the prediction to depend not only on the current input <code>x</code> but also on the history of past inputs.
</p>
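<p>
To make this concrete, here is a toy sketch (plain Julia, not Flux code) of a stateful function in the <code>f(W, x, s)</code> style above: feeding it the same input twice gives different outputs, because the state carries information about what came before.
</p>
<pre><code class="language-julia"># Toy example: the "model" keeps a running sum of its inputs in the state `s`.
f(W, x, s) = (W * (x + s), x + s)   # returns (prediction, new state)

s = 0.0
y1, s = f(2.0, 1.0, s)   # y1 = 2.0, s is now 1.0
y2, s = f(2.0, 1.0, s)   # y2 = 4.0: same input, different output, thanks to `s`</code></pre>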
<p>
The simplest recurrent network looks as follows in Flux, and it should be familiar if you've seen the equations defining an RNN before:
</p>
<pre><code class="language-julia">@net type Recurrent
  Wxy; Wyy; by
  y
  function (x)
    y = tanh( x * Wxy + y{-1} * Wyy + by )
  end
end</code></pre>
<p>
The only difference from a regular feed-forward layer is that we create a variable <code>y</code> which is defined as depending on itself. The <code>y{-1}</code> syntax means "take the value of <code>y</code> from the previous run of the network".
</p>
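<p>
As with other model templates, the definition above would normally be paired with a convenience constructor that fills in the parameters. A minimal sketch of one possibility (the random initialisation and shapes here are illustrative assumptions, not part of the definition above):
</p>
<pre><code class="language-julia"># Hypothetical constructor: random weights, zero bias, zero initial hidden value.
Recurrent(in::Integer, out::Integer) =
  Recurrent(param(randn(in, out)), param(randn(out, out)), param(zeros(1, out)),
            zeros(1, out))</code></pre>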
<p>
Using recurrent layers is straightforward and no different from using feedforward ones in terms of the <code>Chain</code> function and so on. For example:
</p>
<pre><code class="language-julia">model = Chain(
  Affine(784, 20), σ,
  Recurrent(20, 30),
  Recurrent(30, 15))</code></pre>
<p>
Before using the model we need to unroll it. This happens with the <code>unroll</code> function:
</p>
<pre><code class="language-julia">unroll(model, 20)</code></pre>
<p>
This call creates an unrolled, feed-forward version of the model which accepts N (= 20) inputs and generates N predictions at a time. Essentially, the model is replicated N times and Flux ties each copy's hidden outputs <code>y</code> to the hidden inputs of the next copy.
</p>
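<p>
Conceptually, unrolling replaces the self-dependency with explicit chaining of the hidden value through N copies of the model. In the <code>f(W, x, s)</code> notation from earlier, an unroll of length 3 behaves roughly like this (a sketch of the idea, not the code Flux actually generates):
</p>
<pre><code class="language-julia"># Each copy consumes the state produced by the previous one; the result is an
# ordinary feed-forward function over a whole sequence of inputs.
function unrolled3(W, xs, s)
  y1, s = f(W, xs[1], s)
  y2, s = f(W, xs[2], s)
  y3, s = f(W, xs[3], s)
  return [y1, y2, y3], s
end</code></pre>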
<p>
Here's a more complex recurrent layer, an LSTM, and again it should be familiar if you've seen the <a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">equations</a>:
</p>
<pre><code class="language-julia">@net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
  Wxo; Wyo; bo
  Wxc; Wyc; bc
  y; state
  function (x)
    # Gates
    forget = σ( x * Wxf + y{-1} * Wyf + bf )
    input  = σ( x * Wxi + y{-1} * Wyi + bi )
    output = σ( x * Wxo + y{-1} * Wyo + bo )
    # State update and output
    state′ = tanh( x * Wxc + y{-1} * Wyc + bc )
    state  = forget .* state{-1} + input .* state′
    y = output .* tanh(state)
  end
end</code></pre>
<p>
The only unfamiliar part is that we have to define all of the parameters of the LSTM upfront, which adds a few lines at the beginning.
</p>
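<p>
As with the simple recurrent layer above, those upfront parameters would typically be filled in by a constructor. A sketch of one possibility (again, the random initialisation is an illustrative assumption rather than the library's actual constructor):
</p>
<pre><code class="language-julia"># Hypothetical constructor: a weight pair and bias per gate, candidate-state
# weights, and zero-initialised hidden values for `y` and `state`.
LSTM(in::Integer, out::Integer) =
  LSTM(param(randn(in, out)), param(randn(out, out)), param(zeros(1, out)),  # forget
       param(randn(in, out)), param(randn(out, out)), param(zeros(1, out)),  # input
       param(randn(in, out)), param(randn(out, out)), param(zeros(1, out)),  # output
       param(randn(in, out)), param(randn(out, out)), param(zeros(1, out)),  # candidate
       zeros(1, out), zeros(1, out))</code></pre>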
<p>
Flux's very mathematical notation generalises well to handling more complex models. For example, <a href="https://arxiv.org/abs/1409.0473">this neural translation model with alignment</a> can be fairly straightforwardly, and recognisably, translated from the paper into Flux code:
</p>
<pre><code class="language-julia"># A recurrent model which takes a token and returns a context-dependent
# annotation.

@net type Encoder
  forward
  backward
  token -> hcat(forward(token), backward(token))
end

Encoder(in::Integer, out::Integer) =
  Encoder(LSTM(in, out÷2), flip(LSTM(in, out÷2)))

# A recurrent model which takes a sequence of annotations, attends, and returns
# a predicted output token.

@net type Decoder
  attend
  recur
  state; y; N
  function (anns)
    energies = map(ann -> exp(attend(hcat(state{-1}, ann))[1]), seq(anns, N))
    weights = energies./sum(energies)
    ctx = sum(map((α, ann) -> α .* ann, weights, anns))
    (_, state), y = recur((state{-1}, y{-1}), ctx)
    y
  end
end

Decoder(in::Integer, out::Integer; N = 1) =
  Decoder(Affine(in+out, 1),
          unroll1(LSTM(in, out)),
          param(zeros(1, out)), param(zeros(1, out)), N)

# The model

Nalpha  = 5  # The size of the input token vector
Nphrase = 7  # The length of (padded) phrases
Nhidden = 12 # The size of the hidden state

encode = Encoder(Nalpha, Nhidden)
decode = Chain(Decoder(Nhidden, Nhidden, N = Nphrase), Affine(Nhidden, Nalpha), softmax)

model = Chain(
  unroll(encode, Nphrase, stateful = false),
  unroll(decode, Nphrase, stateful = false, seq = false))</code></pre>
<p>
Note that this model exercises some of the more advanced parts of the compiler and isn't stable for general use yet.
</p>
<footer>
<hr/>
<a class="previous" href="templates.html">
<span class="direction">Previous</span>
<span class="title">Model Templates</span>
</a>
<a class="next" href="debugging.html">
<span class="direction">Next</span>
<span class="title">Debugging</span>
</a>
</footer>
</article>
</body>
</html>