<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8"/>
<meta name="viewport" content="width=device-width, initial-scale=1.0"/>
<title>
Simple MNIST · Flux
</title>
<script>
(function(i,s,o,g,r,a,m){i['GoogleAnalyticsObject']=r;i[r]=i[r]||function(){
(i[r].q=i[r].q||[]).push(arguments)},i[r].l=1*new Date();a=s.createElement(o),
m=s.getElementsByTagName(o)[0];a.async=1;a.src=g;m.parentNode.insertBefore(a,m)
})(window,document,'script','https://www.google-analytics.com/analytics.js','ga');
ga('create', 'UA-36890222-9', 'auto');
ga('send', 'pageview');
</script>
<link href="https://cdnjs.cloudflare.com/ajax/libs/normalize/4.2.0/normalize.min.css" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/highlight.js/9.5.0/styles/default.min.css" rel="stylesheet" type="text/css"/>
<link href="https://fonts.googleapis.com/css?family=Lato|Ubuntu+Mono" rel="stylesheet" type="text/css"/>
<link href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.6.3/css/font-awesome.min.css" rel="stylesheet" type="text/css"/>
<link href="../assets/documenter.css" rel="stylesheet" type="text/css"/>
<script>
documenterBaseURL=".."
</script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/require.js/2.2.0/require.min.js" data-main="../assets/documenter.js"></script>
<script src="../../versions.js"></script>
<link href="../../flux.css" rel="stylesheet" type="text/css"/>
</head>
<body>
<nav class="toc">
<h1>
Flux
</h1>
<form class="search" action="../search.html">
<select id="version-selector" onChange="window.location.href=this.value">
<option value="#" selected="selected" disabled="disabled">
Version
</option>
</select>
<input id="search-query" name="q" type="text" placeholder="Search docs"/>
</form>
<ul>
<li>
<a class="toctext" href="../index.html">
Home
</a>
</li>
<li>
<span class="toctext">
Building Models
</span>
<ul>
<li>
<a class="toctext" href="../models/basics.html">
Model Building Basics
</a>
</li>
<li>
<a class="toctext" href="../models/templates.html">
Model Templates
</a>
</li>
<li>
<a class="toctext" href="../models/recurrent.html">
Recurrence
</a>
</li>
<li>
<a class="toctext" href="../models/debugging.html">
Debugging
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
Other APIs
</span>
<ul>
<li>
<a class="toctext" href="../apis/batching.html">
Batching
</a>
</li>
<li>
<a class="toctext" href="../apis/backends.html">
Backends
</a>
</li>
<li>
<a class="toctext" href="../apis/storage.html">
Storing Models
</a>
</li>
</ul>
</li>
<li>
<span class="toctext">
In Action
</span>
<ul>
<li class="current">
<a class="toctext" href="logreg.html">
Simple MNIST
</a>
<ul class="internal"></ul>
</li>
<li>
<a class="toctext" href="char-rnn.html">
Char RNN
</a>
</li>
</ul>
</li>
<li>
<a class="toctext" href="../contributing.html">
Contributing &amp; Help
</a>
</li>
<li>
<a class="toctext" href="../internals.html">
Internals
</a>
</li>
</ul>
</nav>
<article id="docs">
<header>
<nav>
<ul>
<li>
In Action
</li>
<li>
<a href="logreg.html">
Simple MNIST
</a>
</li>
</ul>
<a class="edit-page" href="https://github.com/MikeInnes/Flux.jl/tree/0e325e0425606161ded20064ba0c5d929f497fad/docs/src/examples/logreg.md">
<span class="fa">
</span>
Edit on GitHub
</a>
</nav>
<hr/>
</header>
<h1>
<a class="nav-anchor" id="Recognising-MNIST-Digits-1" href="#Recognising-MNIST-Digits-1">
Recognising MNIST Digits
</a>
</h1>
<p>
This walkthrough will take you through writing a multi-layer perceptron that classifies MNIST digits with high accuracy.
</p>
<p>
First, we load the data using the MNIST package:
</p>
<pre><code class="language-julia">using Flux, MNIST
using Flux: accuracy

data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
train = data[1:50_000]
test = data[50_001:60_000]</code></pre>
<p>
The only Flux-specific function here is
<code>onehot</code>
, which takes a class label and turns it into a one-hot-encoded vector that we can use for training. For example:
</p>
<pre><code class="language-julia">julia> onehot(:b, [:a, :b, :c])
3-element Array{Int64,1}:
 0
 1
 0</code></pre>
<p>
Otherwise, the format of the data is simple enough: it's just a list of tuples from input to output. For example:
</p>
<pre><code class="language-julia">julia> data[1]
([0.0,0.0,0.0, … 0.0,0.0,0.0],[0,0,0,0,0,1,0,0,0,0])</code></pre>
<p>
<code>data[1][1]</code>
is a
<code>28*28 == 784</code>
length vector (mostly zeros due to the black background) and
<code>data[1][2]</code>
is its classification.
</p>
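<p>
As a quick sanity check (a hypothetical snippet, simply restating the sizes described above), we can confirm that each input is a 784-element vector and each label is a 10-element one-hot vector:
</p>
<pre><code class="language-julia">julia> length(data[1][1])  # 28*28 pixels, flattened
784

julia> length(data[1][2])  # one entry per digit class 0–9
10</code></pre>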
<p>
Now we define our model, which will simply be a function from one to the other.
</p>
<pre><code class="language-julia">m = @Chain(
  Input(784),
  Affine(128), relu,
  Affine( 64), relu,
  Affine( 10), softmax)

model = mxnet(m) # Convert to MXNet</code></pre>
<p>
We can try this out on our data already:
</p>
<pre><code class="language-julia">julia> model(tobatch(data[1][1]))
10-element Array{Float64,1}:
 0.10614
 0.0850447
 0.101474
 ...</code></pre>
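<p>
Because the last layer is a softmax, these ten numbers form a probability distribution over the digit classes; if you want to convince yourself, the following (hypothetical) check should return a value very close to one:
</p>
<pre><code class="language-julia">julia> sum(model(tobatch(data[1][1])))  # softmax outputs sum to (approximately) 1</code></pre>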
<p>
The model gives a probability of about 0.1 to each class – which is a way of saying, "I have no idea". This isn't too surprising as we haven't shown it any data yet. This is easy to fix:
</p>
<pre><code class="language-julia">Flux.train!(model, train, η = 1e-3,
            cb = [() -> @show accuracy(m, test)])</code></pre>
<p>
The training step takes about 5 minutes (to make it faster we can do smarter things like batching, as sketched below). If you run this code in Juno, you'll see a progress meter, which you can hover over to see the remaining computation time.
</p>
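<p>
For example, one simple way to form mini-batches is to group the training pairs with plain Julia indexing (a hypothetical sketch only; Flux's own batching support is described on the Batching page):
</p>
<pre><code class="language-julia"># group the 50_000 training pairs into mini-batches of 100 examples each
batches = [train[i:i+99] for i = 1:100:length(train)]</code></pre>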
<p>
Towards the end of the training process, Flux will have reported that the accuracy of the model is now about 90%. We can try it on our data again:
</p>
<pre><code class="language-julia">10-element Array{Float32,1}:
 ...
 5.11423f-7
 0.9354
 3.1033f-5
 0.000127077
 ...</code></pre>
<p>
Notice the class with probability of about 93%, suggesting our model is now very confident about this image. We can use
<code>onecold</code>
to compare the true and predicted classes:
</p>
<pre><code class="language-julia">julia> onecold(data[1][2], 0:9)
5

julia> onecold(model(tobatch(data[1][1])), 0:9)
5</code></pre>
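<p>
<code>onecold</code>
also gives us a simple, hand-rolled way to measure accuracy over the whole test set, equivalent in spirit to the
<code>accuracy</code>
callback used during training (a hypothetical sketch built only from the functions used above):
</p>
<pre><code class="language-julia"># fraction of test examples whose predicted class matches the true class
function manual_accuracy(model, data)
  correct = 0
  for (x, y) in data
    correct += onecold(model(tobatch(x)), 0:9) == onecold(y, 0:9)
  end
  return correct / length(data)
end

manual_accuracy(model, test) # should be roughly 0.9 after training</code></pre>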
<p>
Success!
</p>
<footer>
<hr/>
<a class="previous" href="../apis/storage.html">
<span class="direction">
Previous
</span>
<span class="title">
Storing Models
</span>
</a>
<a class="next" href="char-rnn.html">
<span class="direction">
Next
</span>
<span class="title">
Char RNN
</span>
</a>
</footer>
</article>
</body>
</html>