<!DOCTYPE html>
< html lang = "en" >
< head >
< meta charset = "UTF-8" / >
< meta name = "viewport" content = "width=device-width, initial-scale=1.0" / >
< title >
Recurrence · Flux
< / title >
< script >
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
< / script >
< link href = "https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel = "stylesheet" type = "text/css" / >
< link href = "https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel = "stylesheet" type = "text/css" / >
< link href = "https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel = "stylesheet" type = "text/css" / >
< link href = "https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel = "stylesheet" type = "text/css" / >
< link href = "../assets/documenter.css" rel = "stylesheet" type = "text/css" / >
< script >
documenterBaseURL=".."
< / script >
< script src = "https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main = "../assets/documenter.js" > < / script >
< script src = "../../versions.js" > < / script >
< link href = "../../flux.css" rel = "stylesheet" type = "text/css" / >
< / head >
< body >
< nav class = "toc" >
< h1 >
Flux
< / h1 >
< form class = "search" action = "../search.html" >
< select id = "version-selector" onChange = "window.location.href=this.value" >
< option value = "#" selected = "selected" disabled = "disabled" >
Version
< / option >
< / select >
< input id = "search-query" name = "q" type = "text" placeholder = "Search docs" / >
< / form >
< ul >
< li >
< a class = "toctext" href = "../index.html" >
Home
< / a >
< / li >
< li >
< span class = "toctext" >
Building Models
< / span >
< ul >
< li >
< a class = "toctext" href = "basics.html" >
Model Building Basics
< / a >
< / li >
< li >
< a class = "toctext" href = "templates.html" >
Model Templates
< / a >
< / li >
< li class = "current" >
< a class = "toctext" href = "recurrent.html" >
Recurrence
< / a >
< ul class = "internal" > < / ul >
< / li >
< li >
< a class = "toctext" href = "debugging.html" >
Debugging
< / a >
< / li >
< / ul >
< / li >
< li >
< span class = "toctext" >
Other APIs
< / span >
< ul >
< li >
< a class = "toctext" href = "../apis/batching.html" >
Batching
< / a >
< / li >
< li >
< a class = "toctext" href = "../apis/backends.html" >
Backends
< / a >
< / li >
< li >
< a class = "toctext" href = "../apis/storage.html" >
Storing Models
< / a >
< / li >
< / ul >
< / li >
< li >
< span class = "toctext" >
In Action
< / span >
< ul >
< li >
< a class = "toctext" href = "../examples/logreg.html" >
Simple MNIST
< / a >
< / li >
< li >
< a class = "toctext" href = "../examples/char-rnn.html" >
Char RNN
< / a >
< / li >
< / ul >
< / li >
< li >
< a class = "toctext" href = "../contributing.html" >
Contributing & Help
< / a >
< / li >
< li >
< a class = "toctext" href = "../internals.html" >
Internals
< / a >
< / li >
< / ul >
< / nav >
< article id = "docs" >
< header >
< nav >
< ul >
< li >
Building Models
< / li >
< li >
< a href = "recurrent.html" >
Recurrence
< / a >
< / li >
< / ul >
< a class = "edit-page" href = "https://github.com/MikeInnes/Flux.jl/tree/7a85eff370b7c68d587b49699fa3f71e44993397/docs/src/models/recurrent.md" >
< span class = "fa" >
< / span >
Edit on GitHub
< / a >
< / nav >
< hr / >
< / header >
< h1 >
< a class = "nav-anchor" id = "Recurrent-Models-1" href = "#Recurrent-Models-1" >
Recurrent Models
< / a >
< / h1 >
< p >
< a href = "https://en.wikipedia.org/wiki/Recurrent_neural_network" >
Recurrence
< / a >
is a first-class feature in Flux, and recurrent models are very easy to build and use. Recurrence is often illustrated as a cycle or self-dependency in the graph; it can also be thought of as a hidden output of the network that is fed back in as an input. For example, for a sequence of inputs
< code > x1, x2, x3 ...< / code >
we produce predictions as follows:
< / p >
< pre > < code class = "language-julia" > y1 = f(W, x1) # `f` is the model, `W` represents the parameters
y2 = f(W, x2)
y3 = f(W, x3)
...< / code > < / pre >
< p >
Each evaluation is independent and the prediction made for a given input will always be the same. That makes a lot of sense for, say, MNIST images, but less sense when predicting a sequence. For that case we introduce the hidden state:
< / p >
< pre > < code class = "language-julia" > y1, s = f(W, x1, s)
y2, s = f(W, x2, s)
y3, s = f(W, x3, s)
...< / code > < / pre >
< p >
The state
< code > s< / code >
allows the prediction to depend not only on the current input
< code > x< / code >
but also on the history of past inputs.
< / p >
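< p >
To make the state-threading concrete, here is a hand-written sketch of such an
< code > f< / code >
in plain Julia, mirroring the
< code > f(W, x, s)< / code >
form above. The parameter names and sizes are made up purely for illustration and are not part of Flux's API.
< / p >
< pre > < code class = "language-julia" > # Illustrative only: `W` bundles made-up parameters for 10-dimensional inputs
# and a 5-dimensional hidden state.
W = (Wx = randn(5, 10), Ws = randn(5, 5), b = randn(5))

# One recurrent step; here the new state doubles as the prediction.
function f(W, x, s)
  s′ = tanh.(W.Wx * x .+ W.Ws * s .+ W.b)
  return s′, s′
end

s = zeros(5)                  # initial hidden state
y1, s = f(W, randn(10), s)    # y1 depends only on x1
y2, s = f(W, randn(10), s)    # y2 depends on x2 and, through s, on x1< / code > < / pre >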
< p >
The simplest recurrent network looks as follows in Flux, and it should be familiar if you've seen the equations defining an RNN before:
< / p >
< pre > < code class = "language-julia" > @net type Recurrent
  Wxy; Wyy; by
  y
  function (x)
    y = tanh( x * Wxy + y{-1} * Wyy + by )
  end
end< / code > < / pre >
< p >
The only difference from a regular feed-forward layer is that we create a variable
< code > y< / code >
which is defined as depending on itself. The
< code > y{-1}< / code >
syntax means "take the value of
< code > y< / code >
from the previous run of the network".
< / p >
< p >
Using recurrent layers is straightforward and no different from using feedforward ones in terms of the
< code > Chain< / code >
function, etc. For example:
< / p >
< pre > < code class = "language-julia" > model = Chain(
  Affine(784, 20), σ,
  Recurrent(20, 30),
  Recurrent(30, 15))< / code > < / pre >
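< p >
The
< code > Recurrent(20, 30)< / code >
calls above assume a size-taking convenience constructor, analogous to the constructors described in the Model Templates section. A sketch of what such a constructor might look like follows; the initialisation choices here are purely illustrative:
< / p >
< pre > < code class = "language-julia" > # Illustrative only: build a Recurrent for `in`-dimensional inputs and
# `out`-dimensional outputs, with the hidden output `y` starting at zero.
Recurrent(in::Integer, out::Integer; init = randn) =
  Recurrent(init(in, out), init(out, out), init(1, out), zeros(1, out))< / code > < / pre >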
< p >
Before using the model we need to unroll it. This happens with the
< code > unroll< / code >
function:
< / p >
< pre > < code class = "language-julia" > unroll(model, 20)< / code > < / pre >
< p >
This call creates an unrolled, feed-forward version of the model which accepts N (= 20) inputs and generates N predictions at a time. Essentially, the model is replicated N times and Flux ties each copy's hidden output
< code > y< / code >
to the hidden input of the next copy.
< / p >
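< p >
Conceptually, the unrolled model computes what the hand-written loop below computes; this is only a sketch of the data flow, not the code Flux actually generates (which also handles batching). The
< code > f(W, x, s)< / code >
form is the stateful model from earlier.
< / p >
< pre > < code class = "language-julia" > # Illustrative only: what unrolling to N steps amounts to, written by hand.
function unrolled(f, W, xs, s)
  ys = []
  for x in xs
    y, s = f(W, x, s)   # each step feeds its state into the next
    push!(ys, y)
  end
  return ys, s
end< / code > < / pre >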
< p >
Here's a more complex recurrent layer, an LSTM, and again it should be familiar if you've seen the
< a href = "https://colah.github.io/posts/2015-08-Understanding-LSTMs/" >
equations
< / a >
:
< / p >
< pre > < code class = "language-julia" > @net type LSTM
  Wxf; Wyf; bf
  Wxi; Wyi; bi
  Wxo; Wyo; bo
  Wxc; Wyc; bc
  y; state
  function (x)
    # Gates
    forget = σ( x * Wxf + y{-1} * Wyf + bf )
    input  = σ( x * Wxi + y{-1} * Wyi + bi )
    output = σ( x * Wxo + y{-1} * Wyo + bo )
    # State update and output
    state′ = tanh( x * Wxc + y{-1} * Wyc + bc )
    state  = forget .* state{-1} + input .* state′
    y = output .* tanh(state)
  end
end< / code > < / pre >
< p >
The only unfamiliar part is that we have to define all of the parameters of the LSTM upfront, which adds a few lines at the beginning.
< / p >
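< p >
As before, and again assuming a size-taking convenience constructor has been defined for
< code > LSTM< / code >
, it can be dropped into the earlier chain in place of the simple recurrent layers and unrolled in exactly the same way:
< / p >
< pre > < code class = "language-julia" > model = Chain(
  Affine(784, 20), σ,
  LSTM(20, 30),
  LSTM(30, 15))

unroll(model, 20)< / code > < / pre >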
< footer >
< hr / >
< a class = "previous" href = "templates.html" >
< span class = "direction" >
Previous
< / span >
< span class = "title" >
Model Templates
< / span >
< / a >
< a class = "next" href = "debugging.html" >
< span class = "direction" >
Next
< / span >
< span class = "title" >
Debugging
< / span >
< / a >
< / footer >
< / article >
< / body >
< / html >