Flux.jl/latest/models/recurrent.html

325 lines
9.8 KiB
HTML
Raw Normal View History

2017-01-16 16:51:09 +00:00
<!DOCTYPE html>
2017-01-17 20:06:28 +00:00
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>
2017-01-18 01:18:15 +00:00
Recurrence · Flux
2017-01-17 20:06:28 +00:00
</title>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
2017-01-16 16:51:09 +00:00
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
2017-01-17 20:06:28 +00:00
</script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
2017-01-18 01:18:15 +00:00
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
2017-01-17 20:06:28 +00:00
<script>
2017-01-18 01:18:15 +00:00
documenterBaseURL=".."
2017-01-17 20:06:28 +00:00
</script>
2017-01-18 01:18:15 +00:00
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
<script src="../../versions.js"></script>
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
2017-01-17 20:06:28 +00:00
</head>
<body>
<nav class="toc">
<h1>
Flux
</h1>
2017-01-18 01:18:15 +00:00
<form class="search" action="../search.html">
2017-01-17 20:06:28 +00:00
<select id="version-selector" onChange="window.location.href=this.value">
<option value="#" selected="selected" disabled="disabled">
Version
</option>
</select>
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
</form>
<ul>
<li>
2017-01-18 01:18:15 +00:00
<a class="toctext" href="../index.html">
2017-01-17 20:06:28 +00:00
Home
</a>
</li>
2017-01-18 12:45:25 +00:00
<li>
2017-01-18 23:22:30 +00:00
<span class="toctext">
Building Models
</span>
<ul>
<li>
<a class="toctext" href="basics.html">
2017-02-01 13:48:25 +00:00
Model Building Basics
2017-01-18 23:22:30 +00:00
</a>
</li>
2017-02-02 07:48:56 +00:00
<li>
<a class="toctext" href="templates.html">
Model Templates
</a>
</li>
2017-01-18 23:22:30 +00:00
<li class="current">
<a class="toctext" href="recurrent.html">
2017-01-18 12:45:25 +00:00
Recurrence
2017-01-18 23:22:30 +00:00
</a>
<ul class="internal"></ul>
</li>
<li>
<a class="toctext" href="debugging.html">
2017-01-18 12:45:25 +00:00
Debugging
2017-01-18 23:22:30 +00:00
</a>
</li>
</ul>
2017-01-18 12:45:25 +00:00
</li>
2017-02-18 15:11:53 +00:00
<li>
2017-02-20 10:53:09 +00:00
<span class="toctext">
Other APIs
</span>
<ul>
<li>
2017-02-20 11:05:06 +00:00
<a class="toctext" href="../apis/batching.html">
2017-02-18 15:11:53 +00:00
Batching
2017-02-20 10:53:09 +00:00
</a>
</li>
<li>
2017-02-20 11:05:06 +00:00
<a class="toctext" href="../apis/backends.html">
2017-02-18 15:11:53 +00:00
Backends
2017-02-20 10:53:09 +00:00
</a>
</li>
2017-02-28 16:50:27 +00:00
<li>
<a class="toctext" href="../apis/storage.html">
Storing Models
</a>
</li>
2017-02-20 10:53:09 +00:00
</ul>
2017-02-18 15:11:53 +00:00
</li>
2017-01-17 20:06:28 +00:00
<li>
<span class="toctext">
2017-01-18 12:45:25 +00:00
In Action
2017-01-17 20:06:28 +00:00
</span>
<ul>
2017-01-18 01:18:15 +00:00
<li>
2017-01-18 12:45:25 +00:00
<a class="toctext" href="../examples/logreg.html">
Logistic Regression
2017-01-18 01:18:15 +00:00
</a>
2017-01-17 20:06:28 +00:00
</li>
2017-02-28 16:21:45 +00:00
<li>
<a class="toctext" href="../examples/char-rnn.html">
Char RNN
</a>
</li>
2017-01-17 20:06:28 +00:00
</ul>
</li>
2017-01-18 01:18:15 +00:00
<li>
<a class="toctext" href="../contributing.html">
Contributing &amp; Help
</a>
</li>
2017-01-18 12:45:25 +00:00
<li>
<a class="toctext" href="../internals.html">
Internals
</a>
</li>
2017-01-17 20:06:28 +00:00
</ul>
</nav>
<article id="docs">
<header>
<nav>
<ul>
<li>
2017-01-18 23:22:30 +00:00
Building Models
</li>
<li>
2017-01-18 01:18:15 +00:00
<a href="recurrent.html">
Recurrence
2017-01-17 20:06:28 +00:00
</a>
</li>
</ul>
2017-03-08 20:55:13 +00:00
<a class="edit-page" href="https://github.com/MikeInnes/Flux.jl/tree/9c9feb9ba0eeb88207223630109baf92d1b87516/docs/src/models/recurrent.md">
2017-01-17 20:06:28 +00:00
<span class="fa">
</span>
Edit on GitHub
</a>
</nav>
<hr/>
</header>
<h1>
2017-01-18 01:18:15 +00:00
<a class="nav-anchor" id="Recurrent-Models-1" href="#Recurrent-Models-1">
Recurrent Models
2017-01-17 20:06:28 +00:00
</a>
</h1>
2017-01-18 01:18:15 +00:00
<p>
2017-02-02 08:19:53 +00:00
<a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">
Recurrence
</a>
is a first-class feature in Flux, and recurrent models are very easy to build and use. Recurrences are often illustrated as cycles or self-dependencies in the graph; they can also be thought of as an extra hidden output from, and input to, the network. For example, for a sequence of inputs
<code>x1, x2, x3 ...</code>
we produce predictions as follows:
</p>
<pre><code class="language-julia">y1 = f(W, x1) # `f` is the model, `W` represents the parameters
y2 = f(W, x2)
y3 = f(W, x3)
...</code></pre>
<p>
Each evaluation is independent and the prediction made for a given input will always be the same. That makes a lot of sense for, say, MNIST images, but less sense when predicting a sequence. For that case we introduce the hidden state:
</p>
<pre><code class="language-julia">y1, s = f(W, x1, s)
y2, s = f(W, x2, s)
y3, s = f(W, x3, s)
...</code></pre>
<p>
The state
<code>s</code>
allows the prediction to depend not only on the current input
<code>x</code>
but also on the history of past inputs.
</p>
<p>
The simplest recurrent network looks as follows in Flux, and it should be familiar if you&#39;ve seen the equations defining an RNN before:
</p>
<pre><code class="language-julia">@net type Recurrent
Wxy; Wyy; by
y
function (x)
y = tanh( x * Wxy + y{-1} * Wyy + by )
end
end</code></pre>
<p>
The only difference from a regular feed-forward layer is that we create a variable
<code>y</code>
which is defined as depending on itself. The
<code>y{-1}</code>
syntax means &quot;take the value of
<code>y</code>
from the previous run of the network&quot;.
</p>
<p>
Using recurrent layers is straightforward: they work with
<code>Chain</code>
etc. no differently from feed-forward layers. For example:
</p>
<pre><code class="language-julia">model = Chain(
Affine(784, 20), σ,
Recurrent(20, 30),
Recurrent(30, 15))</code></pre>
<p>
Before using the model we need to unroll it. This happens with the
<code>unroll</code>
function:
</p>
<pre><code class="language-julia">unroll(model, 20)</code></pre>
<p>
This call creates an unrolled, feed-forward version of the model which accepts N (= 20) inputs and generates N predictions at a time. Essentially, the model is replicated N times, and Flux ties the hidden output
<code>y</code>
of each copy to the hidden input of the next.
</p>
<p>
Here&#39;s a more complex recurrent layer, an LSTM, and again it should be familiar if you&#39;ve seen the
<a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">
equations
</a>:
</p>
<pre><code class="language-julia">@net type LSTM
Wxf; Wyf; bf
Wxi; Wyi; bi
Wxo; Wyo; bo
Wxc; Wyc; bc
y; state
function (x)
# Gates
forget = σ( x * Wxf + y{-1} * Wyf + bf )
input = σ( x * Wxi + y{-1} * Wyi + bi )
output = σ( x * Wxo + y{-1} * Wyo + bo )
# State update and output
state = tanh( x * Wxc + y{-1} * Wyc + bc )
state = forget .* state{-1} + input .* state
y = output .* tanh(state)
end
end</code></pre>
<p>
The only unfamiliar part is that we have to define all of the parameters of the LSTM upfront, which adds a few lines at the beginning.
</p>
<p>
Flux&#39;s very mathematical notation generalises well to handling more complex models. For example,
<a href="https://arxiv.org/abs/1409.0473">
this neural translation model with alignment
</a>
can be fairly straightforwardly, and recognisably, translated from the paper into Flux code:
</p>
<pre><code class="language-julia"># A recurrent model which takes a token and returns a context-dependent
# annotation.
@net type Encoder
forward
backward
token -&gt; hcat(forward(token), backward(token))
end
Encoder(in::Integer, out::Integer) =
Encoder(LSTM(in, out÷2), flip(LSTM(in, out÷2)))
# A recurrent model which takes a sequence of annotations, attends, and returns
# a predicted output token.
@net type Decoder
attend
recur
state; y; N
function (anns)
energies = map(ann -&gt; exp(attend(hcat(state{-1}, ann))[1]), seq(anns, N))
weights = energies./sum(energies)
ctx = sum(map((α, ann) -&gt; α .* ann, weights, anns))
(_, state), y = recur((state{-1},y{-1}), ctx)
y
end
end
Decoder(in::Integer, out::Integer; N = 1) =
Decoder(Affine(in+out, 1),
unroll1(LSTM(in, out)),
param(zeros(1, out)), param(zeros(1, out)), N)
# The model
Nalpha = 5 # The size of the input token vector
Nphrase = 7 # The length of (padded) phrases
Nhidden = 12 # The size of the hidden state
encode = Encoder(Nalpha, Nhidden)
decode = Chain(Decoder(Nhidden, Nhidden, N = Nphrase), Affine(Nhidden, Nalpha), softmax)
model = Chain(
unroll(encode, Nphrase, stateful = false),
unroll(decode, Nphrase, stateful = false, seq = false))</code></pre>
<p>
Note that this model exercises some of the more advanced parts of the compiler and isn&#39;t stable for general use yet.
2017-01-18 01:18:15 +00:00
</p>
2017-01-17 20:06:28 +00:00
<footer>
<hr/>
2017-02-02 07:48:56 +00:00
<a class="previous" href="templates.html">
2017-01-17 20:06:28 +00:00
<span class="direction">
Previous
</span>
<span class="title">
2017-02-02 07:48:56 +00:00
Model Templates
2017-01-18 01:18:15 +00:00
</span>
</a>
<a class="next" href="debugging.html">
<span class="direction">
Next
</span>
<span class="title">
Debugging
2017-01-17 20:06:28 +00:00
</span>
</a>
</footer>
</article>
</body>
</html>