<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>
Recurrence · Flux
</title>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
<script>
documenterBaseURL=".."
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
<script src="../../versions.js"></script>
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<nav class="toc">
<h1>
Flux
</h1>
<form class="search" action="../search.html">
<select id="version-selector" onChange="window.location.href=this.value">
<option value="#" selected="selected" disabled="disabled">
Version
</option>
</select>
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
</form>
<ul>
<li>
<a class="toctext" href="../index.html">
Home
</a>
</li>
<li>
<span class="toctext">
Building Models
</span>
<ul>
<li>
<a class="toctext" href="basics.html">
Model Building Basics
</a>
</li>
<li>
<a class="toctext" href="templates.html">
Model Templates
</a>
</li>
<li class="current">
<a class="toctext" href="recurrent.html">
Recurrence
</a>
<ul class="internal"></ul>
</li>
<li>
<a class="toctext" href="debugging.html">
Debugging
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
Other APIs
</span>
<ul>
<li>
<a class="toctext" href="../apis/batching.html">
Batching
</a>
</li>
<li>
<a class="toctext" href="../apis/backends.html">
Backends
</a>
</li>
<li>
<a class="toctext" href="../apis/storage.html">
Storing Models
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
In Action
</span>
<ul>
<li>
<a class="toctext" href="../examples/logreg.html">
Logistic Regression
</a>
</li>
<li>
<a class="toctext" href="../examples/char-rnn.html">
Char RNN
</a>
</li>
</ul>
</li>
<li>
<a class="toctext" href="../contributing.html">
Contributing & Help
</a>
</li>
<li>
<a class="toctext" href="../internals.html">
Internals
</a>
</li>
</ul>
</nav>
<article id="docs">
<header>
<nav>
<ul>
<li>
Building Models
</li>
<li>
<a href="recurrent.html">
Recurrence
</a>
</li>
</ul>
<a class="edit-page" href="https://github.com/MikeInnes/Flux.jl/tree/15b3ce1adad9e07367e68c8c3758380b02d6a4a1/docs/src/models/recurrent.md">
<span class="fa">
</span>
Edit on GitHub
</a>
</nav>
<hr/>
</header>
<h1>
<a class="nav-anchor" id="Recurrent-Models-1" href="#Recurrent-Models-1">
Recurrent Models
</a>
</h1>
<p>
<a href="https://en.wikipedia.org/wiki/Recurrent_neural_network">
Recurrence
</a>
is a first-class feature in Flux, and recurrent models are very easy to build and use. Recurrences are often illustrated as cycles or self-dependencies in the graph; they can also be thought of as a hidden output from, and input to, the network. For example, for a sequence of inputs
<code>x1, x2, x3 ...</code>
we produce predictions as follows:
</p>
<pre><code class="language-julia">y1 = f(W, x1) # `f` is the model, `W` represents the parameters
y2 = f(W, x2)
y3 = f(W, x3)
...</code></pre>
<p>
Each evaluation is independent and the prediction made for a given input will always be the same. That makes a lot of sense for, say, MNIST images, but less sense when predicting a sequence. For that case we introduce the hidden state:
</p>
<pre><code class="language-julia">y1, s = f(W, x1, s)
y2, s = f(W, x2, s)
y3, s = f(W, x3, s)
...</code></pre>
<p>
The state
<code>s</code>
allows the prediction to depend not only on the current input
<code>x</code>
but also on the history of past inputs.
</p>
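<p>
To make this concrete, here is a purely illustrative sketch in plain Julia (the function
<code>f</code>
below is hypothetical and not part of Flux): threading a state through each call makes every prediction depend on the inputs seen so far.
</p>
<pre><code class="language-julia"># A toy stateful "model": the hidden state accumulates the inputs,
# so the prediction for x2 is influenced by x1, and so on.
function f(W, x, s)
    s = s .+ W * x      # update the hidden state with the current input
    y = tanh.(s)        # the prediction depends on the whole history
    return y, s
end

W = randn(3, 3)
s = zeros(3)
y1, s = f(W, rand(3), s)
y2, s = f(W, rand(3), s)  # depends on x1 through the carried state `s`</code></pre>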
<p>
The simplest recurrent network looks as follows in Flux, and it should be familiar if you've seen the equations defining an RNN before:
</p>
<pre><code class="language-julia">@net type Recurrent
  Wxy; Wyy; by
  y
  function (x)
    y = tanh( x * Wxy + y{-1} * Wyy + by )
  end
end</code></pre>
<p>
The only difference from a regular feed-forward layer is that we create a variable
<code>y</code>
which is defined as depending on itself. The
<code>y{-1}</code>
syntax means "take the value of
<code>y</code>
from the previous run of the network".
</p>
<p>
Using recurrent layers is straightforward and no different from using feed-forward ones in terms of the
<code>Chain</code>
macro etc. For example:
</p>
<pre><code class="language-julia">model = Chain(
  Affine(784, 20), σ,
  Recurrent(20, 30),
  Recurrent(30, 15))</code></pre>
<p>
Before using the model we need to unroll it. This happens with the
<code>unroll</code>
function:
</p>
<pre><code class="language-julia">unroll(model, 20)</code></pre>
<p>
This call creates an unrolled, feed-forward version of the model which accepts N (= 20) inputs and generates N predictions at a time. Essentially, the model is replicated N times and Flux ties the hidden outputs
<code>y</code>
of each copy to the hidden inputs of the next.
</p>
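<p>
Conceptually, the unrolled model behaves like the following hand-written loop (a rough, hypothetical sketch in plain Julia, not Flux's actual implementation): the state produced by one copy of the network is fed into the next.
</p>
<pre><code class="language-julia"># Manually "unrolled" recurrence: one copy of the step function per input,
# with the hidden state threaded from each copy to the next.
function unrolled(f, W, xs, s0)
    s, ys = s0, []
    for x in xs             # N inputs, N replicated steps
        y, s = f(W, x, s)   # step t's hidden output becomes step t+1's hidden input
        push!(ys, y)
    end
    return ys               # N predictions
end</code></pre>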
<p>
Here's a more complex recurrent layer, an LSTM, and again it should be familiar if you've seen the
<a href="https://colah.github.io/posts/2015-08-Understanding-LSTMs/">
equations
</a>
:
</p>
<pre><code class="language-julia">@net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
  Wxo; Wyo; bo
  Wxc; Wyc; bc
  y; state
  function (x)
    # Gates
    forget = σ( x * Wxf + y{-1} * Wyf + bf )
    input  = σ( x * Wxi + y{-1} * Wyi + bi )
    output = σ( x * Wxo + y{-1} * Wyo + bo )
    # State update and output
    state′ = tanh( x * Wxc + y{-1} * Wyc + bc )
    state  = forget .* state{-1} + input .* state′
    y = output .* tanh(state)
  end
end</code></pre>
<p>
The only unfamiliar part is that we have to define all of the parameters of the LSTM upfront, which adds a few lines at the beginning.
</p>
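<p>
The examples further down construct layers with calls like
<code>LSTM(in, out)</code>
; a constructor along the following lines (a hypothetical sketch with an arbitrary initialisation, following the same positional-field pattern as the
<code>Encoder</code>
and
<code>Decoder</code>
constructors below) is one way to fill in those parameters:
</p>
<pre><code class="language-julia"># Hypothetical constructor sketch: one (Wx, Wy, b) triple per gate, in the
# same order as the fields above, followed by the initial `y` and `state`.
gate(in, out) = (param(randn(in, out)/100), param(randn(out, out)/100), param(zeros(1, out)))

LSTM(in::Integer, out::Integer) =
  LSTM(gate(in, out)...,      # forget gate:  Wxf, Wyf, bf
       gate(in, out)...,      # input gate:   Wxi, Wyi, bi
       gate(in, out)...,      # output gate:  Wxo, Wyo, bo
       gate(in, out)...,      # cell update:  Wxc, Wyc, bc
       param(zeros(1, out)),  # y
       param(zeros(1, out)))  # state</code></pre>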
<p>
Flux's very mathematical notation generalises well to more complex models. For example,
<a href="https://arxiv.org/abs/1409.0473">
this neural translation model with alignment
</a>
can be fairly straightforwardly, and recognisably, translated from the paper into Flux code:
</p>
<pre><code class="language-julia"># A recurrent model which takes a token and returns a context-dependent
# annotation.
@net type Encoder
  forward
  backward
  token -> hcat(forward(token), backward(token))
end

Encoder(in::Integer, out::Integer) =
  Encoder(LSTM(in, out÷2), flip(LSTM(in, out÷2)))

# A recurrent model which takes a sequence of annotations, attends, and returns
# a predicted output token.
@net type Decoder
  attend
  recur
  state; y; N
  function (anns)
    energies = map(ann -> exp(attend(hcat(state{-1}, ann))[1]), seq(anns, N))
    weights = energies ./ sum(energies)
    ctx = sum(map((α, ann) -> α .* ann, weights, anns))
    (_, state), y = recur((state{-1}, y{-1}), ctx)
    y
  end
end

Decoder(in::Integer, out::Integer; N = 1) =
  Decoder(Affine(in+out, 1),
          unroll1(LSTM(in, out)),
          param(zeros(1, out)), param(zeros(1, out)), N)

# The model
Nalpha = 5   # The size of the input token vector
Nphrase = 7  # The length of (padded) phrases
Nhidden = 12 # The size of the hidden state

encode = Encoder(Nalpha, Nhidden)
decode = Chain(Decoder(Nhidden, Nhidden, N = Nphrase), Affine(Nhidden, Nalpha), softmax)

model = Chain(
  unroll(encode, Nphrase, stateful = false),
  unroll(decode, Nphrase, stateful = false, seq = false))</code></pre>
<p>
Note that this model exercises some of the more advanced parts of the compiler and isn't stable for general use yet.
</p>
<footer>
<hr/>
<a class="previous" href="templates.html">
<span class="direction">
Previous
</span>
<span class="title">
Model Templates
</span>
</a>
<a class="next" href="debugging.html">
<span class="direction">
Next
</span>
<span class="title">
Debugging
</span>
</a>
</footer>
</article>
</body>
</html>