<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>
Logistic Regression · Flux
</title>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');

ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');

</script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
<script>
documenterBaseURL=".."
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
<script src="../../versions.js"></script>
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<nav class="toc">
<h1>
Flux
</h1>
<form class="search" action="../search.html">
<select id="version-selector" onChange="window.location.href=this.value">
<option value="#" selected="selected" disabled="disabled">
Version
</option>
</select>
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
</form>
<ul>
<li>
<a class="toctext" href="../index.html">
Home
</a>
</li>
<li>
<span class="toctext">
Building Models
</span>
<ul>
<li>
<a class="toctext" href="../models/basics.html">
Model Building Basics
</a>
</li>
<li>
<a class="toctext" href="../models/templates.html">
Model Templates
</a>
</li>
<li>
<a class="toctext" href="../models/recurrent.html">
Recurrence
</a>
</li>
<li>
<a class="toctext" href="../models/debugging.html">
Debugging
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
Other APIs
</span>
<ul>
<li>
<a class="toctext" href="../apis/batching.html">
Batching
</a>
</li>
<li>
<a class="toctext" href="../apis/backends.html">
Backends
</a>
</li>
<li>
<a class="toctext" href="../apis/storage.html">
Storing Models
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
In Action
</span>
<ul>
<li class="current">
<a class="toctext" href="logreg.html">
Logistic Regression
</a>
<ul class="internal"></ul>
</li>
<li>
<a class="toctext" href="char-rnn.html">
Char RNN
</a>
</li>
</ul>
</li>
<li>
<a class="toctext" href="../contributing.html">
Contributing & Help
</a>
</li>
<li>
<a class="toctext" href="../internals.html">
Internals
</a>
</li>
</ul>
</nav>
<article id="docs">
<header>
<nav>
<ul>
<li>
In Action
</li>
<li>
<a href="logreg.html">
Logistic Regression
</a>
</li>
</ul>
<a class="edit-page" href="https://github.com/MikeInnes/Flux.jl/tree/1c317eeefec910170cc72a4fe09ac54e187b3624/docs/src/examples/logreg.md">
<span class="fa">
</span>
Edit on GitHub
</a>
</nav>
<hr/>
</header>
<h1>
<a class="nav-anchor" id="Logistic-Regression-with-MNIST-1" href="#Logistic-Regression-with-MNIST-1">
Logistic Regression with MNIST
</a>
</h1>
<p>
This walkthrough takes you through writing a multi-layer perceptron that classifies MNIST digits with about 90% accuracy.
</p>
<p>
First, we load the data using the MNIST package:
</p>
<pre><code class="language-julia">using Flux, MNIST

data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
train = data[1:50_000]
test = data[50_001:60_000]</code></pre>
<p>
The only Flux-specific function here is <code>onehot</code>, which takes a class label and turns it into a one-hot-encoded vector that we can use for training. For example:
</p>
<pre><code class="language-julia">julia> onehot(:b, [:a, :b, :c])
3-element Array{Int64,1}:
 0
 1
 0</code></pre>
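<p>
To make the encoding concrete, here is a minimal sketch of what a one-hot encoder does, written in plain Julia. The name <code>my_onehot</code> is hypothetical and is not part of Flux; Flux's own <code>onehot</code> may well be implemented differently.
</p>
<pre><code class="language-julia"># Hypothetical helper, not Flux's implementation: put a 1 at the position
# of `label` within `labels` and a 0 everywhere else.
function my_onehot(label, labels)
    label in labels || throw(ArgumentError("label is not in labels"))
    return [Int(l == label) for l in labels]
end

my_onehot(:b, [:a, :b, :c])  # [0, 1, 0]
my_onehot(5, 0:9)            # the digit 5 becomes [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]</code></pre>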
<p>
Otherwise, the format of the data is simple: it is just a list of (input, output) tuples. For example:
</p>
<pre><code class="language-julia">julia> data[1]
([0.0,0.0,0.0, … 0.0,0.0,0.0],[0,0,0,0,0,1,0,0,0,0])</code></pre>
<p>
<code>data[1][1]</code> is a vector of length <code>28*28 == 784</code> (mostly zeros due to the black background) and <code>data[1][2]</code> is its classification.
</p>
<p>
Now we define our model, which is simply a function from the input vector to its classification.
</p>
<pre><code class="language-julia">m = Chain(
  Input(784),
  Affine(128), relu,
  Affine( 64), relu,
  Affine( 10), softmax)

model = tf(m)</code></pre>
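<p>
As a rough illustration of what such a chain computes, the sketch below writes a 784 → 128 → 64 → 10 stack by hand in plain Julia. Everything in it is an assumption made for illustration: the names <code>affine_params</code>, <code>my_relu</code>, <code>my_softmax</code> and <code>forward</code>, the layout of the weight matrices, and the random initialisation. It is not how Flux builds or initialises its <code>Affine</code> layers.
</p>
<pre><code class="language-julia"># Hypothetical stand-ins for the layers above, not Flux's own definitions.
affine_params(nin, nout) = (0.01 .* randn(nout, nin), zeros(nout))

my_relu(x) = max.(x, 0.0)

function my_softmax(x)
    e = exp.(x .- maximum(x))  # subtract the maximum for numerical stability
    return e ./ sum(e)
end

W1, b1 = affine_params(784, 128)
W2, b2 = affine_params(128, 64)
W3, b3 = affine_params(64, 10)

# Each affine step is a matrix multiplication plus a bias; the final softmax
# turns the ten scores into probabilities.
function forward(x)
    h1 = my_relu(W1 * x .+ b1)
    h2 = my_relu(W2 * h1 .+ b2)
    return my_softmax(W3 * h2 .+ b3)
end

y = forward(rand(784))
length(y)  # 10, one probability per digit
sum(y)     # approximately 1</code></pre>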
<p>
We can try this out on our data already:
</p>
<pre><code class="language-julia">julia> model(data[1][1])
10-element Array{Float64,1}:
 0.10614
 0.0850447
 0.101474
 ...</code></pre>
<p>
The model gives a probability of about 0.1 to each class – which is a way of saying, "I have no idea". This isn't too surprising as we haven't shown it any data yet. This is easy to fix:
</p>
<pre><code class="language-julia">Flux.train!(model, train, test, η = 1e-4)</code></pre>
<p>
The training step takes about 5 minutes (to make it faster we can do smarter things like batching). If you run this code in Juno, you'll see a progress meter, which you can hover over to see the remaining computation time.
</p>
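<p>
To give an idea of what a training step involves, here is a minimal sketch of gradient descent on a single softmax layer with a cross-entropy loss, in plain Julia. It is not the body of <code>Flux.train!</code>: the names <code>softmax_probs</code>, <code>crossentropy</code> and <code>sgd_step!</code>, the tiny made-up input, and the choice of loss are all assumptions made for illustration. The step size plays the role of <code>η</code>.
</p>
<pre><code class="language-julia"># Hypothetical sketch, not Flux's training loop.
function softmax_probs(z)
    e = exp.(z .- maximum(z))
    return e ./ sum(e)
end

# Cross-entropy between predicted probabilities p and a one-hot target y.
crossentropy(p, y) = -sum(y .* log.(p))

# One gradient step for the model p = softmax(W * x + b). For this loss the
# gradient with respect to the scores W * x + b is simply p - y.
function sgd_step!(W, b, x, y, η)
    p = softmax_probs(W * x .+ b)
    δ = p .- y
    W .-= η .* (δ * x')
    b .-= η .* δ
    return crossentropy(p, y)  # loss measured before the update
end

x = [1.0, 0.5, -0.5]  # made-up input with three features
y = [0, 1, 0]         # one-hot target for the second of three classes
W = zeros(3, 3)
b = zeros(3)

losses = [sgd_step!(W, b, x, y, 0.1) for _ in 1:100]
losses[1]    # log(3), since the untrained model predicts 1/3 for each class
losses[end]  # smaller: the model now favours the second class</code></pre>
<p>
A real training run repeats steps like this over every example in the training set, which is why batching helps.
</p>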
<p>
Towards the end of the training process, Flux will have reported that the accuracy of the model is now about 90%. We can try it on our data again:
</p>
<pre><code class="language-julia">julia> model(data[1][1])
10-element Array{Float32,1}:
 ...
 5.11423f-7
 0.9354
 3.1033f-5
 0.000127077
 ...</code></pre>
<p>
Notice the class at 93%, suggesting our model is very confident about this image. We can use <code>onecold</code> to compare the true and predicted classes:
</p>
<pre><code class="language-julia">julia> onecold(data[1][2], 0:9)
5

julia> onecold(model(data[1][1]), 0:9)
5</code></pre>
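<p>
The same comparison can be repeated over a whole dataset to estimate accuracy. The sketch below is plain Julia with hypothetical names (<code>my_onecold</code> and <code>accuracy</code> are not Flux functions), and it uses a stand-in predictor and three made-up examples so that it runs without the trained model.
</p>
<pre><code class="language-julia"># Hypothetical helper, not Flux's onecold: return the label whose entry in
# the vector is largest.
my_onecold(y, labels) = labels[argmax(y)]

# Fraction of (input, target) pairs whose predicted label matches the target.
function accuracy(predict, data, labels)
    hits = count(pair -> my_onecold(predict(pair[1]), labels) ==
                         my_onecold(pair[2], labels), data)
    return hits / length(data)
end

# Made-up data: each "input" is already a score vector, and the stand-in
# predictor `identity` returns it unchanged.
toy = [([0.1, 0.7, 0.2], [0, 1, 0]),
       ([0.8, 0.1, 0.1], [1, 0, 0]),
       ([0.3, 0.4, 0.3], [0, 0, 1])]

accuracy(identity, toy, [:a, :b, :c])  # 2 of the 3 examples match</code></pre>
<p>
Assuming the trained model can be called like a function, as in the examples above, the analogous call on the held-out data would be <code>accuracy(model, test, 0:9)</code>.
</p>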
<p>
Success!
</p>
<footer>
<hr/>
<a class="previous" href="../apis/storage.html">
<span class="direction">
Previous
</span>
<span class="title">
Storing Models
</span>
</a>
<a class="next" href="char-rnn.html">
<span class="direction">
Next
</span>
<span class="title">
Char RNN
</span>
</a>
</footer>
</article>
</body>
</html>