Flux.jl/v0.2.0/examples/logreg.html
2017-05-02 13:01:23 +00:00

<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>
Simple MNIST · Flux
</title>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
<script>
documenterBaseURL=".."
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
<script src="../../versions.js"></script>
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<nav class="toc">
<h1>
Flux
</h1>
<form class="search" action="../search.html">
<select id="version-selector" onChange="window.location.href=this.value">
<option value="#" selected="selected" disabled="disabled">
Version
</option>
</select>
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
</form>
<ul>
<li>
<a class="toctext" href="../index.html">
Home
</a>
</li>
<li>
<span class="toctext">
Building Models
</span>
<ul>
<li>
<a class="toctext" href="../models/basics.html">
Model Building Basics
</a>
</li>
<li>
<a class="toctext" href="../models/templates.html">
Model Templates
</a>
</li>
<li>
<a class="toctext" href="../models/recurrent.html">
Recurrence
</a>
</li>
<li>
<a class="toctext" href="../models/debugging.html">
Debugging
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
Other APIs
</span>
<ul>
<li>
<a class="toctext" href="../apis/batching.html">
Batching
</a>
</li>
<li>
<a class="toctext" href="../apis/backends.html">
Backends
</a>
</li>
<li>
<a class="toctext" href="../apis/storage.html">
Storing Models
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
In Action
</span>
<ul>
<li class="current">
<a class="toctext" href="logreg.html">
Simple MNIST
</a>
<ul class="internal"></ul>
</li>
<li>
<a class="toctext" href="char-rnn.html">
Char RNN
</a>
</li>
</ul>
</li>
<li>
<a class="toctext" href="../contributing.html">
Contributing &amp; Help
</a>
</li>
<li>
<a class="toctext" href="../internals.html">
Internals
</a>
</li>
</ul>
</nav>
<article id="docs">
<header>
<nav>
<ul>
<li>
In Action
</li>
<li>
<a href="logreg.html">
Simple MNIST
</a>
</li>
</ul>
</nav>
<hr/>
</header>
<h1>
<a class="nav-anchor" id="Recognising-MNIST-Digits-1" href="#Recognising-MNIST-Digits-1">
Recognising MNIST Digits
</a>
</h1>
<p>
This walkthrough shows how to write a multi-layer perceptron that classifies MNIST digits with about 90% accuracy.
</p>
<p>
First, we load the data using the MNIST package:
</p>
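<pre><code class="language-julia">using Flux, MNIST
using Flux: accuracy
# Pair each image (a 784-element vector) with its one-hot-encoded label
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
# Train on the first 50,000 examples and hold out the last 10,000 for testing
train = data[1:50_000]
test = data[50_001:60_000]</code></pre>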
<p>
The only Flux-specific function here is
<code>onehot</code>, which takes a class label and turns it into a one-hot-encoded vector that we can use for training. For example:
</p>
<pre><code class="language-julia">julia&gt; onehot(:b, [:a, :b, :c])
3-element Array{Int64,1}:
0
1
0</code></pre>
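<p>
To make the encoding concrete, here is a minimal plain-Julia sketch of the same idea. The name
<code>my_onehot</code>
is hypothetical and this is not Flux&#39;s implementation; it only assumes that the label appears in the list of classes.
</p>
<pre><code class="language-julia"># Hypothetical helper, not part of Flux: one-hot encode `label` against `labels`.
function my_onehot(label, labels)
    # 1 at the position where the label matches, 0 everywhere else
    return [Int(l == label) for l in labels]
end

my_onehot(:b, [:a, :b, :c])  # [0, 1, 0]
my_onehot(5, 0:9)            # a 1 in the sixth position, since the classes start at 0</code></pre>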
<p>
Otherwise, the format of the data is simple: it&#39;s just a list of (input, output) tuples. For example:
</p>
<pre><code class="language-julia">julia&gt; data[1]
([0.0,0.0,0.0, … 0.0,0.0,0.0],[0,0,0,0,0,1,0,0,0,0])</code></pre>
<p>
<code>data[1][1]</code>
is a vector of length
<code>28*28 == 784</code>
(mostly zeros due to the black background) and
<code>data[1][2]</code>
is its one-hot-encoded label, here the digit 5.
</p>
<p>
Now we define our model, which will simply be a function from the 784-element input vector to the 10-element output vector.
</p>
<pre><code class="language-julia">m = @Chain(
Input(784),
Affine(128), relu,
Affine( 64), relu,
Affine( 10), softmax)
model = mxnet(m) # Convert to MXNet</code></pre>
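<p>
For intuition, the chain above can be read as the following forward pass in plain Julia. This is only a sketch under stated assumptions: the names
<code>AffineSketch</code>,
<code>relu_sketch</code>,
<code>softmax_sketch</code>
and
<code>forward</code>
are hypothetical, the small random weights are placeholders, and Flux&#39;s own layers, initialisation and array layout may differ. It assumes Julia 1.0 or later.
</p>
<pre><code class="language-julia"># Hypothetical stand-ins for the layers in the chain; not Flux's implementations.
relu_sketch(x) = max.(x, 0)

function softmax_sketch(x)
    e = exp.(x .- maximum(x))  # subtract the maximum for numerical stability
    return e ./ sum(e)
end

# An affine layer computes W*x + b.
struct AffineSketch
    W::Matrix{Float64}
    b::Vector{Float64}
end
# Small random weights stand in for a real initialisation scheme.
AffineSketch(nin::Int, nout::Int) = AffineSketch(0.01 .* randn(nout, nin), zeros(nout))
(a::AffineSketch)(x) = a.W * x .+ a.b

l1 = AffineSketch(784, 128)
l2 = AffineSketch(128, 64)
l3 = AffineSketch(64, 10)

# 784 inputs, two hidden layers with relu, then 10 class probabilities
forward(x) = softmax_sketch(l3(relu_sketch(l2(relu_sketch(l1(x))))))

y = forward(rand(784))
length(y)  # 10
sum(y)     # approximately 1, since softmax outputs a probability distribution</code></pre>
<p>
With weights this small, every class receives a probability close to 0.1, which is what an untrained model is expected to produce.
</p>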
<p>
We can try this out on our data already:
</p>
<pre><code class="language-julia">julia&gt; model(tobatch(data[1][1]))
10-element Array{Float64,1}:
0.10614
0.0850447
0.101474
...</code></pre>
<p>
The model gives a probability of about 0.1 to each class, which is a way of saying, &quot;I have no idea&quot;. This isn&#39;t too surprising, as we haven&#39;t shown it any data yet. This is easy to fix:
</p>
<pre><code class="language-julia">Flux.train!(model, train, η = 1e-3,
cb = [()-&gt;@show accuracy(m, test)])</code></pre>
<p>
Training takes about 5 minutes; techniques such as batching would make it faster. If you run this code in Juno, you&#39;ll see a progress meter, which you can hover over to see the remaining computation time.
</p>
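<p>
As one illustration of batching, the sketch below groups a list of (input, output) tuples into mini-batches in plain Julia, stacking each group column-wise into matrices. The helper
<code>make_batches</code>
and the batch size are hypothetical, the toy data merely stands in for MNIST, and this does not use Flux&#39;s batching API. It assumes Julia 1.0 or later.
</p>
<pre><code class="language-julia"># Hypothetical helper, not Flux's batching API: group (input, output) tuples
# into mini-batches, with one example per matrix column.
function make_batches(data, batchsize)
    batches = []
    for chunk in Iterators.partition(data, batchsize)
        xs = hcat([first(pair) for pair in chunk]...)  # features by batch
        ys = hcat([last(pair) for pair in chunk]...)   # classes by batch
        push!(batches, (xs, ys))
    end
    return batches
end

# Toy data standing in for MNIST: 5 examples with 4 features and 3 classes each.
toy = [(rand(4), [1, 0, 0]) for i in 1:5]
batches = make_batches(toy, 2)
length(batches)      # 3 (batches of 2, 2 and 1 examples)
size(batches[1][1])  # (4, 2)
size(batches[3][2])  # (3, 1)</code></pre>
<p>
Processing several examples per step lets the backend use matrix operations instead of many small vector operations, which is typically where the speed-up comes from.
</p>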
<p>
By the end of training, Flux will report that the accuracy of the model is about 90%. We can try the model on our data again:
</p>
<pre><code class="language-julia">julia&gt; model(tobatch(data[1][1]))
10-element Array{Float32,1}:
...
5.11423f-7
0.9354
3.1033f-5
0.000127077
...</code></pre>
<p>
Notice that one class now has a probability of about 0.94, suggesting our model is very confident about this image. We can use
<code>onecold</code>
to compare the true and predicted classes:
</p>
<pre><code class="language-julia">julia&gt; onecold(data[1][2], 0:9)
5
julia&gt; onecold(model(tobatch(data[1][1])), 0:9)
5</code></pre>
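<p>
Conceptually,
<code>onecold</code>
inverts
<code>onehot</code>
by picking the label at the position of the largest entry, and accuracy is the fraction of examples whose predicted and true labels agree. The following plain-Julia sketch illustrates both; the names
<code>my_onecold</code>
and
<code>my_accuracy</code>
are hypothetical and are not Flux&#39;s implementations. It assumes Julia 1.0 or later for
<code>argmax</code>.
</p>
<pre><code class="language-julia"># Hypothetical helpers, not part of Flux.
# Return the label whose position holds the largest score.
my_onecold(scores, labels) = labels[argmax(scores)]

# Fraction of (input, one-hot label) pairs that `predict` classifies correctly.
function my_accuracy(predict, data, labels)
    correct = count(my_onecold(predict(x), labels) == my_onecold(y, labels) for (x, y) in data)
    return correct / length(data)
end

my_onecold([0, 0, 0, 0, 0, 1, 0, 0, 0, 0], 0:9)  # 5
my_onecold([0.1, 0.7, 0.2], [:a, :b, :c])        # :b

# A toy predictor that returns its input as the scores, to exercise my_accuracy.
toy = [([0.9, 0.1], [1, 0]), ([0.2, 0.8], [1, 0])]
my_accuracy(identity, toy, [:cat, :dog])  # 0.5</code></pre>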
<p>
Success!
</p>
<footer>
<hr/>
<a class="previous" href="../apis/storage.html">
<span class="direction">
Previous
</span>
<span class="title">
Storing Models
</span>
</a>
<a class="next" href="char-rnn.html">
<span class="direction">
Next
</span>
<span class="title">
Char RNN
</span>
</a>
</footer>
</article>
</body>
</html>