Flux.jl/release-0.2/models/recurrent.html

2017-05-02 13:01:23 +00:00
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>
Recurrence · Flux
</title>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
<script>
documenterBaseURL=".."
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
<script src="../../versions.js"></script>
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<nav class="toc">
<h1>
Flux
</h1>
<form class="search" action="../search.html">
<select id="version-selector" onChange="window.location.href=this.value">
<option value="#" selected="selected" disabled="disabled">
Version
</option>
</select>
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
</form>
<ul>
<li>
<a class="toctext" href="../index.html">
Home
</a>
</li>
<li>
<span class="toctext">
Building Models
</span>
<ul>
<li>
<a class="toctext" href="basics.html">
Model Building Basics
</a>
</li>
<li>
<a class="toctext" href="templates.html">
Model Templates
</a>
</li>
<li class="current">
<a class="toctext" href="recurrent.html">
Recurrence
</a>
<ul class="internal"></ul>
</li>
<li>
<a class="toctext" href="debugging.html">
Debugging
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
Other APIs
</span>
<ul>
<li>
<a class="toctext" href="../apis/batching.html">
Batching
</a>
</li>
<li>
<a class="toctext" href="../apis/backends.html">
Backends
</a>
</li>
<li>
<a class="toctext" href="../apis/storage.html">
Storing Models
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
In Action
</span>
<ul>
<li>
<a class="toctext" href="../examples/logreg.html">
Simple MNIST
</a>
</li>
<li>
<a class="toctext" href="../examples/char-rnn.html">
Char RNN
</a>
</li>
</ul>
</li>
<li>
<a class="toctext" href="../contributing.html">
Contributing &amp; Help
</a>
</li>
<li>
<a class="toctext" href="../internals.html">
Internals
</a>
</li>
</ul>
</nav>
<article id="docs">
<header>
<nav>
</nav>
<hr/>
</header>
<h1>
<a class="nav-anchor" id="Recurrent-Models-1" href="#Recurrent-Models-1">
Recurrent Models
</a>
</h1>
<p>
<a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">
Recurrence
</a>
is a first-class feature in Flux, and recurrent models are easy to build and use. Recurrences are often illustrated as cycles or self-dependencies in the graph; they can also be thought of as a hidden output from, and hidden input to, the network. To see why, first consider a model without recurrence: for a sequence of inputs
<code>x1, x2, x3 ...</code>
we produce predictions as follows:
</p>
<pre><code class="language-julia">y1 = f(W, x1) # `f` is the model, `W` represents the parameters
y2 = f(W, x2)
y3 = f(W, x3)
...</code></pre>
<p>
Each evaluation is independent and the prediction made for a given input will always be the same. That makes a lot of sense for, say, MNIST images, but less sense when predicting a sequence. For that case we introduce the hidden state:
</p>
<pre><code class="language-julia">y1, s = f(W, x1, s)
y2, s = f(W, x2, s)
y3, s = f(W, x3, s)
...</code></pre>
<p>
The state
<code>s</code>
allows the prediction to depend not only on the current input
<code>x</code>
but also on the history of past inputs.
</p>
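<p>
As a minimal sketch of this idea in plain Julia (not Flux code), the step function below threads a state through two calls. The names
<code>f</code>
,
<code>W.Wx</code>
and
<code>W.Ws</code>
, and the choice to reuse the output as the state, are assumptions made purely for illustration:
</p>
<pre><code class="language-julia"># Hypothetical one-step model: returns a prediction and the updated state.
function f(W, x, s)
    y = tanh.(x * W.Wx .+ s * W.Ws)  # prediction depends on input and state
    return y, y                      # here the new state is simply the output
end

W = (Wx = [0.5 -0.3; 0.1 0.8], Ws = [0.2 0.0; 0.0 0.2])
x = [1.0 2.0]      # a 1×2 row vector, fed in twice
s = zeros(1, 2)    # initial hidden state

y1, s = f(W, x, s)
y2, s = f(W, x, s)  # same input, but the state now carries history</code></pre>
<p>
Feeding the same input twice gives two different predictions, because the second call sees the state left behind by the first.
</p>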
<p>
The simplest recurrent network looks as follows in Flux, and it should be familiar if you&#39;ve seen the equations defining an RNN before:
</p>
<pre><code class="language-julia">@net type Recurrent
  Wxy; Wyy; by
  y
  function (x)
    y = tanh( x * Wxy + y{-1} * Wyy + by )
  end
end</code></pre>
<p>
The only difference from a regular feed-forward layer is that we create a variable
<code>y</code>
which is defined as depending on itself. The
<code>y{-1}</code>
syntax means &quot;take the value of
<code>y</code>
from the previous run of the network&quot;.
</p>
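<p>
One way to picture
<code>y{-1}</code>
, assuming nothing about how Flux actually compiles it, is a layer object that stores its last output. The
<code>RecurrentSketch</code>
type below is hypothetical plain Julia, using row-vector inputs as in the definition above:
</p>
<pre><code class="language-julia"># Plain-Julia sketch (not Flux's implementation): the layer remembers
# its previous output between calls.
mutable struct RecurrentSketch
    Wxy::Matrix{Float64}
    Wyy::Matrix{Float64}
    by::Matrix{Float64}
    y::Matrix{Float64}   # output of the previous run, initially zeros
end

RecurrentSketch(nin::Integer, nout::Integer) =
    RecurrentSketch(0.1 .* randn(nin, nout), 0.1 .* randn(nout, nout),
                    zeros(1, nout), zeros(1, nout))

function (m::RecurrentSketch)(x)
    # m.y on the right-hand side plays the role of y{-1}
    m.y = tanh.(x * m.Wxy .+ m.y * m.Wyy .+ m.by)
    return m.y
end

layer = RecurrentSketch(3, 2)
x = [1.0 0.0 -1.0]
a = layer(x)
b = layer(x)   # computed from the stored `a`, so it generally differs</code></pre>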
<p>
Using recurrent layers is straightforward and no different from using feed-forward ones; for example, they can be combined with
<code>Chain</code>
in the same way:
</p>
<pre><code class="language-julia">model = Chain(
  Affine(784, 20), σ,
  Recurrent(20, 30),
  Recurrent(30, 15))</code></pre>
<p>
Before using the model we need to unroll it. This happens with the
<code>unroll</code>
function:
</p>
<pre><code class="language-julia">unroll(model, 20)</code></pre>
<p>
This call creates an unrolled, feed-forward version of the model which accepts N (= 20) inputs and generates N predictions at a time. Essentially, the model is replicated N times and Flux ties the hidden outputs
<code>y</code>
to hidden inputs.
</p>
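<p>
Conceptually, unrolling can be sketched in plain Julia as chaining N copies of a step function. This is only an illustration of the idea, not how Flux implements
<code>unroll</code>
;
<code>unroll_sketch</code>
,
<code>cell</code>
and
<code>running_sum</code>
are hypothetical names:
</p>
<pre><code class="language-julia"># `cell(x, s)` returns `(y, s_new)`; the unrolled function chains n copies of it.
function unroll_sketch(cell, n::Integer, s0)
    return function (xs)
        length(xs) == n || throw(ArgumentError("expected $n inputs"))
        s = s0
        ys = []
        for x in xs
            y, s = cell(x, s)   # hidden output of one copy feeds the next
            push!(ys, y)
        end
        return ys
    end
end

# A toy cell: the state is a running sum, and the prediction is that sum.
running_sum(x, s) = (x + s, x + s)

unrolled = unroll_sketch(running_sum, 3, 0)
ys = unrolled([1, 2, 3])   # three predictions: 1, 3, 6</code></pre>
<p>
The unrolled function is feed-forward in the sense that it takes all N inputs at once and holds no state between calls.
</p>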
<p>
Here&#39;s a more complex recurrent layer, an LSTM, and again it should be familiar if you&#39;ve seen the
<a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">
equations
</a>
:
</p>
<pre><code class="language-julia">@net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
  Wxo; Wyo; bo
  Wxc; Wyc; bc
  y; state
  function (x)
    # Gates
    forget = σ( x * Wxf + y{-1} * Wyf + bf )
    input  = σ( x * Wxi + y{-1} * Wyi + bi )
    output = σ( x * Wxo + y{-1} * Wyo + bo )
    # State update and output
    state = tanh( x * Wxc + y{-1} * Wyc + bc )
    state = forget .* state{-1} + input .* state
    y = output .* tanh(state)
  end
end</code></pre>
<p>
The only unfamiliar part is that we have to define all of the parameters of the LSTM upfront, which adds a few lines at the beginning.
</p>
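<p>
For comparison, the same update can be written as an ordinary Julia function with the previous output and state passed in explicitly. This is a sketch rather than Flux's implementation;
<code>lstm_step</code>
, the parameter container
<code>p</code>
and the small random initialisation are assumptions for illustration:
</p>
<pre><code class="language-julia">σ(x) = 1 / (1 + exp(-x))

# `y_prev` and `state_prev` stand in for y{-1} and state{-1}.
function lstm_step(p, x, y_prev, state_prev)
    # Gates
    forget = σ.(x * p.Wxf .+ y_prev * p.Wyf .+ p.bf)
    input  = σ.(x * p.Wxi .+ y_prev * p.Wyi .+ p.bi)
    output = σ.(x * p.Wxo .+ y_prev * p.Wyo .+ p.bo)
    # State update and output
    candidate = tanh.(x * p.Wxc .+ y_prev * p.Wyc .+ p.bc)
    state = forget .* state_prev .+ input .* candidate
    y = output .* tanh.(state)
    return y, state
end

nin, nout = 4, 3
w(r, c) = 0.1 .* randn(r, c)   # small random weights
p = (Wxf = w(nin, nout), Wyf = w(nout, nout), bf = zeros(1, nout),
     Wxi = w(nin, nout), Wyi = w(nout, nout), bi = zeros(1, nout),
     Wxo = w(nin, nout), Wyo = w(nout, nout), bo = zeros(1, nout),
     Wxc = w(nin, nout), Wyc = w(nout, nout), bc = zeros(1, nout))

x = randn(1, nin)
y, state = lstm_step(p, x, zeros(1, nout), zeros(1, nout))
y, state = lstm_step(p, x, y, state)   # second step reuses y and state</code></pre>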
<p>
Flux&#39;s very mathematical notation generalises well to handling more complex models. For example,
<a href="https://arxiv.org/abs/1409.0473">
this neural translation model with alignment
</a>
can be fairly straightforwardly, and recognisably, translated from the paper into Flux code:
</p>
<pre><code class="language-julia"># A recurrent model which takes a token and returns a context-dependent
# annotation.

@net type Encoder
  forward
  backward
  token -&gt; hcat(forward(token), backward(token))
end

Encoder(in::Integer, out::Integer) =
  Encoder(LSTM(in, out÷2), flip(LSTM(in, out÷2)))

# A recurrent model which takes a sequence of annotations, attends, and returns
# a predicted output token.

@net type Decoder
  attend
  recur
  state; y; N
  function (anns)
    energies = map(ann -&gt; exp(attend(hcat(state{-1}, ann))[1]), seq(anns, N))
    weights = energies./sum(energies)
    ctx = sum(map((α, ann) -&gt; α .* ann, weights, anns))
    (_, state), y = recur((state{-1},y{-1}), ctx)
    y
  end
end

Decoder(in::Integer, out::Integer; N = 1) =
  Decoder(Affine(in+out, 1),
          unroll1(LSTM(in, out)),
          param(zeros(1, out)), param(zeros(1, out)), N)

# The model

Nalpha  =  5 # The size of the input token vector
Nphrase =  7 # The length of (padded) phrases
Nhidden = 12 # The size of the hidden state

encode = Encoder(Nalpha, Nhidden)
decode = Chain(Decoder(Nhidden, Nhidden, N = Nphrase), Affine(Nhidden, Nalpha), softmax)

model = Chain(
  unroll(encode, Nphrase, stateful = false),
  unroll(decode, Nphrase, stateful = false, seq = false))</code></pre>
<p>
Note that this model exercises some of the more advanced parts of the compiler and isn&#39;t stable for general use yet.
</p>
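<p>
The attention step inside
<code>Decoder</code>
is a softmax-weighted sum of annotations, which can be sketched in isolation in plain Julia. Here
<code>score</code>
is a hypothetical stand-in for the learned
<code>attend</code>
layer, and the annotation values are made up:
</p>
<pre><code class="language-julia"># Three made-up annotations and a previous decoder state, all 1×2 row vectors.
anns  = [[1.0 0.0], [0.0 1.0], [1.0 1.0]]
state = [0.5 -0.5]

# Hypothetical scoring function; the model above learns this as `attend`.
score(state, ann) = sum(state .* ann)

energies = map(ann -&gt; exp(score(state, ann)), anns)
weights  = energies ./ sum(energies)                    # normalised to sum to 1
ctx      = sum(map((α, ann) -&gt; α .* ann, weights, anns)) # weighted sum of annotations</code></pre>
<p>
Annotations with a higher score against the current state contribute more to the context vector that is passed on to the recurrent step.
</p>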
<footer>
<hr/>
<a class="previous" href="templates.html">
<span class="direction">
Previous
</span>
<span class="title">
Model Templates
</span>
</a>
<a class="next" href="debugging.html">
<span class="direction">
Next
</span>
<span class="title">
Debugging
</span>
</a>
</footer>
</article>
</body>
</html>