<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>Model Building Basics · Flux</title>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
<script>
documenterBaseURL=".."
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
<script src="../../versions.js"></script>
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<nav class="toc">
<h1>Flux</h1>
<form class="search" action="../search.html">
<select id="version-selector" onChange="window.location.href=this.value">
<option value="#" selected="selected" disabled="disabled">Version</option>
</select>
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
</form>
<ul>
<li><a class="toctext" href="../index.html">Home</a></li>
<li>
<span class="toctext">Building Models</span>
<ul>
<li class="current">
<a class="toctext" href="basics.html">Model Building Basics</a>
<ul class="internal">
<li><a class="toctext" href="#Net-Functions-1">Net Functions</a></li>
<li><a class="toctext" href="#The-Model-1">The Model</a></li>
<li><a class="toctext" href="#Parameters-1">Parameters</a></li>
<li><a class="toctext" href="#Layers-1">Layers</a></li>
<li><a class="toctext" href="#Combining-Layers-1">Combining Layers</a></li>
<li><a class="toctext" href="#Dressed-like-a-model-1">Dressed like a model</a></li>
</ul>
</li>
<li><a class="toctext" href="templates.html">Model Templates</a></li>
<li><a class="toctext" href="recurrent.html">Recurrence</a></li>
<li><a class="toctext" href="debugging.html">Debugging</a></li>
</ul>
</li>
<li>
<span class="toctext">Other APIs</span>
<ul>
<li><a class="toctext" href="../apis/batching.html">Batching</a></li>
<li><a class="toctext" href="../apis/backends.html">Backends</a></li>
<li><a class="toctext" href="../apis/storage.html">Storing Models</a></li>
</ul>
</li>
<li>
<span class="toctext">In Action</span>
<ul>
<li><a class="toctext" href="../examples/logreg.html">Simple MNIST</a></li>
<li><a class="toctext" href="../examples/char-rnn.html">Char RNN</a></li>
</ul>
</li>
<li><a class="toctext" href="../contributing.html">Contributing &amp; Help</a></li>
<li><a class="toctext" href="../internals.html">Internals</a></li>
</ul>
</nav>
<article id="docs">
<header>
<nav>
<ul>
<li>Building Models</li>
<li><a href="basics.html">Model Building Basics</a></li>
</ul>
<a class="edit-page" href="https://github.com/MikeInnes/Flux.jl/tree/7a85eff370b7c68d587b49699fa3f71e44993397/docs/src/models/basics.md">
<span class="fa"></span>
Edit on GitHub
</a>
</nav>
<hr/>
</header>
<h1><a class="nav-anchor" id="Model-Building-Basics-1" href="#Model-Building-Basics-1">Model Building Basics</a></h1>
<h2><a class="nav-anchor" id="Net-Functions-1" href="#Net-Functions-1">Net Functions</a></h2>
<p>Flux's core feature is the <code>@net</code> macro, which adds some superpowers to regular ol' Julia functions. Consider this simple function with the <code>@net</code> annotation applied:</p>
<pre><code class="language-julia">@net f(x) = x .* x
f([1,2,3]) == [1,4,9]</code></pre>
<p>This behaves as expected, but we have some extra features. For example, we can convert the function to run on <a href="https://www.tensorflow.org/">TensorFlow</a> or <a href="https://github.com/dmlc/MXNet.jl">MXNet</a>:</p>
<pre><code class="language-julia">f_mxnet = mxnet(f)
f_mxnet([1,2,3]) == [1.0, 4.0, 9.0]</code></pre>
<p>Simples! Flux took care of a lot of boilerplate for us and just ran the multiplication on MXNet. MXNet can optimise this code for us, taking advantage of parallelism or running the code on a GPU.</p>
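<p>The TensorFlow backend works the same way – a sketch, assuming the analogous <code>tf</code> converter described in <a href="../apis/backends.html">Backends</a>:</p>
<pre><code class="language-julia">f_tf = tf(f)
f_tf([1,2,3]) == [1.0, 4.0, 9.0]</code></pre>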
<p>Using MXNet, we can get the gradient of the function, too:</p>
<pre><code class="language-julia">back!(f_mxnet, [1,1,1], [1,2,3]) == ([2.0, 4.0, 6.0],)</code></pre>
<p><code>f</code> is effectively <code>x^2</code>, so the gradient is <code>2x</code> as expected.</p>
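<p>The <code>[1,1,1]</code> here is the seed Δ – the gradient with respect to <code>f</code>'s output. Since <code>f</code> is elementwise, each input's gradient is just <code>2x .* Δ</code>; a quick sketch, assuming the same calling convention as above:</p>
<pre><code class="language-julia">back!(f_mxnet, [0,1,0], [1,2,3]) == ([0.0, 4.0, 0.0],) # only the second element contributes</code></pre>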
<h2><a class="nav-anchor" id="The-Model-1" href="#The-Model-1">The Model</a></h2>
<p>The core concept in Flux is the <em>model</em>. This corresponds to what might be called a "layer" or "module" in other frameworks. A model is simply a differentiable function with parameters. Given a model <code>m</code> we can do things like:</p>
<pre><code class="language-julia">m(x)           # See what the model does to an input vector `x`
back!(m, Δ, x) # backpropagate the gradient `Δ` through `m`
update!(m, η)  # update the parameters of `m` using the gradient</code></pre>
<p>We can implement a model however we like as long as it fits this interface. But as hinted above, <code>@net</code> is a particularly easy way to do it, because it gives you these functions for free.</p>
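<p>To make that interface concrete, here's a hand-written model – a toy <code>Scale</code> type of our own, not something Flux provides – that multiplies its input by a single learned factor. It assumes the additive "tweak" convention used throughout this page:</p>
<pre><code class="language-julia">import Flux: back!, update!

type Scale
  w::Float64
  Δw::Float64
end
Scale(w) = Scale(w, 0.0)

(m::Scale)(x) = m.w .* x                                    # forward pass
back!(m::Scale, Δ, x) = (m.Δw += sum(Δ .* x); (Δ .* m.w,)) # accumulate the tweak for `w`, return `(dx,)`
update!(m::Scale, η) = (m.w += η * m.Δw; m.Δw = 0.0; m)    # apply and reset the accumulated tweak</code></pre>
<p>Anything satisfying these three functions can be dropped into a larger network, as we'll see below.</p>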
<h2><a class="nav-anchor" id="Parameters-1" href="#Parameters-1">Parameters</a></h2>
<p>Consider how we'd write a logistic regression. We just take the Julia code and add <code>@net</code>.</p>
<pre><code class="language-julia">@net logistic(W, b, x) = softmax(x * W .+ b)

W = randn(10, 2)
b = randn(1, 2)
x = rand(1, 10) # [0.563 0.346 0.780 …] – fake data
y = [1 0]       # our desired classification of `x`

ŷ = logistic(W, b, x) # [0.46 0.54]</code></pre>
<p>The network takes a set of 10 features (<code>x</code>, a row vector) and produces a classification <code>ŷ</code>, equivalent to a probability of true vs false. <code>softmax</code> scales the output to sum to one, so that we can interpret it as a probability distribution.</p>
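<p>To see that normalisation concretely, <code>softmax</code> amounts to exponentiating and rescaling – a plain-Julia sketch of the idea, not Flux's actual definition:</p>
<pre><code class="language-julia">mysoftmax(xs) = exp.(xs) ./ sum(exp.(xs))
mysoftmax([1.0, 2.0]) # ≈ [0.269, 0.731] – positive entries that sum to one</code></pre>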
<p>We can use MXNet and get gradients:</p>
<pre><code class="language-julia">logisticm = mxnet(logistic)
logisticm(W, b, x)                    # [0.46 0.54]
back!(logisticm, [0.1 -0.1], W, b, x) # (dW, db, dx)</code></pre>
<p>The gradient <code>[0.1 -0.1]</code> says that we want to increase <code>ŷ[1]</code> and decrease <code>ŷ[2]</code> to get closer to <code>y</code>. <code>back!</code> gives us the tweaks we need to make to each input (<code>W</code>, <code>b</code>, <code>x</code>) in order to do this. If we add these tweaks to <code>W</code> and <code>b</code> it will predict <code>ŷ</code> more accurately.</p>
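<p>Applied by hand, that update might look like this – a sketch, assuming <code>back!</code> returns one plain array per input:</p>
<pre><code class="language-julia">dW, db, dx = back!(logisticm, [0.1 -0.1], W, b, x)
W += 0.1 * dW # nudge each parameter a little way along its tweak
b += 0.1 * db
logisticm(W, b, x) # now slightly closer to `y`</code></pre>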
<p>Treating parameters like <code>W</code> and <code>b</code> as inputs can get unwieldy in larger networks. Since they are both global we can use them directly:</p>
<pre><code class="language-julia">@net logistic(x) = softmax(x * W .+ b)</code></pre>
<p>However, this gives us a problem: how do we get their gradients?</p>
<p>Flux solves this with the <code>Param</code> wrapper:</p>
<pre><code class="language-julia">W = param(randn(10, 2))
b = param(randn(1, 2))
@net logistic(x) = softmax(x * W .+ b)</code></pre>
<p>This works as before, but now <code>W.x</code> stores the real value and <code>W.Δx</code> stores its gradient, so we don't have to manage it by hand. We can even use <code>update!</code> to apply the gradients automatically.</p>
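<p>Concretely – assuming the fields behave as just described:</p>
<pre><code class="language-julia">W.x  # the current weight matrix
W.Δx # its accumulated gradient</code></pre>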
<pre><code class="language-julia">logisticm(x) # [0.46 0.54]

back!(logisticm, [1 -1], x)
update!(logisticm, 0.1)

logisticm(x) # [0.51 0.49]</code></pre>
<p>Our network got a little closer to the target <code>y</code>. Now we just need to repeat this millions of times.</p>
<p><em>Side note:</em> We obviously need a way to calculate the "tweak" <code>[0.1 -0.1]</code> automatically. We can use a loss function like <em>mean squared error</em> for this:</p>
<pre><code class="language-julia"># How wrong is ŷ?
mse([0.46, 0.54], [1, 0]) == 0.292
# What change to `ŷ` will reduce the wrongness?
back!(mse, -1, [0.46, 0.54], [1, 0]) == [0.54, -0.54]</code></pre>
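<p>Putting the pieces together, the loop we'd repeat "millions of times" might look like this – our own sketch, assuming <code>mse</code> and <code>back!</code> behave exactly as shown above:</p>
<pre><code class="language-julia">for i = 1:1_000_000
  ŷ = logisticm(x)
  Δ = back!(mse, -1, ŷ, y) # the tweak that reduces the wrongness
  back!(logisticm, Δ, x)   # push it back through the model
  update!(logisticm, 0.1)  # and apply it to `W` and `b`
end</code></pre>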
<h2><a class="nav-anchor" id="Layers-1" href="#Layers-1">Layers</a></h2>
<p>Bigger networks contain many affine transformations like <code>W * x + b</code>. We don't want to write out the definition every time we use it. Instead, we can factor this out by making a function that produces models:</p>
<pre><code class="language-julia">function create_affine(in, out)
  W = param(randn(out, in))
  b = param(randn(out))
  @net x -> W * x + b
end

affine1 = create_affine(3, 2)
affine1([1, 2, 3])</code></pre>
<p>Flux has a <a href="templates.html">more powerful syntax</a> for this pattern, but also provides a bunch of layers out of the box. So we can instead write:</p>
<pre><code class="language-julia">affine1 = Affine(5, 5)
affine2 = Affine(5, 5)

x = rand(1, 5) # a row vector of 5 features

softmax(affine1(x)) # [0.167952 0.186325 0.176683 0.238571 0.23047]
softmax(affine2(x)) # [0.125361 0.246448 0.21966 0.124596 0.283935]</code></pre>
<h2><a class="nav-anchor" id="Combining-Layers-1" href="#Combining-Layers-1">Combining Layers</a></h2>
<p>A more complex model usually involves many basic layers like <code>Affine</code>, where we use the output of one layer as the input to the next:</p>
<pre><code class="language-julia">mymodel1(x) = softmax(affine2(σ(affine1(x))))
mymodel1(x1) # [0.187935, 0.232237, 0.169824, 0.230589, 0.179414]</code></pre>
<p>This syntax is again a little unwieldy for larger networks, so Flux provides another template of sorts to create the function for us:</p>
<pre><code class="language-julia">mymodel2 = Chain(affine1, σ, affine2, softmax)
mymodel2(x1) # [0.187935, 0.232237, 0.169824, 0.230589, 0.179414]</code></pre>
<p><code>mymodel2</code> is exactly equivalent to <code>mymodel1</code> because it simply calls the provided functions in sequence. We don't have to predefine the affine layers and can also write this as:</p>
<pre><code class="language-julia">mymodel3 = Chain(
  Affine(5, 5), σ,
  Affine(5, 5), softmax)</code></pre>
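<p>Conceptually, <code>Chain</code> is little more than function composition – roughly like this sketch (our own definition, not Flux's actual one):</p>
<pre><code class="language-julia">mychain(fs...) = x -> foldl((acc, f) -> f(acc), x, fs)
mymodel4 = mychain(affine1, σ, affine2, softmax) # behaves like `mymodel1`</code></pre>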
<h2><a class="nav-anchor" id="Dressed-like-a-model-1" href="#Dressed-like-a-model-1">Dressed like a model</a></h2>
<p>We noted above that a model is a function with trainable parameters. Normal functions like <code>exp</code> are actually models too – they just happen to have 0 parameters. Flux doesn't care, and anywhere that you use one, you can use the other. For example, <code>Chain</code> will happily work with regular functions:</p>
<pre><code class="language-julia">foo = Chain(exp, sum, log)
foo([1,2,3]) == 3.408 == log(sum(exp([1,2,3])))</code></pre>
<footer>
<hr/>
<a class="previous" href="../index.html">
<span class="direction">Previous</span>
<span class="title">Home</span>
</a>
<a class="next" href="templates.html">
<span class="direction">Next</span>
<span class="title">Model Templates</span>
</a>
</footer>
</article>
</body>
</html>