# Logistic Regression with MNIST

[WIP]

This walkthrough will take you through writing a multi-layer perceptron that classifies MNIST digits with high accuracy.

First, we load the data using the MNIST package:

```julia
using Flux, MNIST

# Each example pairs a 784-pixel image vector with its one-hot-encoded label.
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
train = data[1:50_000]
test = data[50_001:60_000]
```

The only Flux-specific function here is `onehot`, which takes a class label and turns it into a one-hot-encoded vector that we can use for training. For example:

```julia
julia> onehot(:b, [:a, :b, :c])
3-element Array{Int64,1}:
 0
 1
 0
```
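
In our data the labels range over `0:9`, so encoding, say, the digit `5` produces the same kind of vector we'll see stored in `data` below (a sketch; the output follows directly from the definition of one-hot encoding):

```julia
julia> onehot(5, 0:9)
10-element Array{Int64,1}:
 0
 0
 0
 0
 0
 1
 0
 0
 0
 0
```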

Otherwise, the format of the data is simple enough: it's just a list of `(input, output)` tuples. For example:

```julia
julia> data[1]
([0.0,0.0,0.0, … 0.0,0.0,0.0],[0,0,0,0,0,1,0,0,0,0])
```

`data[1][1]` is a `28*28 == 784`-element vector (mostly zeros, since the background pixels are black) and `data[1][2]` is its classification.
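
As a quick sanity check on those shapes (a sketch; `img` is a name introduced here just for illustration), we can reshape the flat vector back into its `28×28` image grid:

```julia
# The input is a flat 784-element vector; reshaping recovers the image grid.
img = reshape(data[1][1], 28, 28)
@assert size(img) == (28, 28)
@assert length(data[1][2]) == 10  # one-hot label over the ten digit classes
```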

Now we define our model, which will simply be a function from one to the other.

```julia
m = Chain(
  Input(784),            # 784-pixel input vector
  Affine(128), relu,     # fully connected layer with 128 outputs
  Affine( 64), relu,
  Affine( 10), softmax)  # ten classes, normalised to probabilities

model = tf(m)  # run the model on the TensorFlow backend
```

We can try this out on our data already:

```julia
julia> model(data[1][1])
10-element Array{Float64,1}:
 0.10614
 0.0850447
 0.101474
 ...
```
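
Because the final layer is a `softmax`, those ten outputs always form a probability distribution; a one-line check (a sketch, reusing the `model` defined above):

```julia
# Softmax outputs are non-negative and sum to one (up to floating-point error).
sum(model(data[1][1])) ≈ 1
```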

The model gives a probability of about 0.1 to each class, which is a way of saying "I have no idea". This isn't too surprising, as we haven't shown it any data yet. This is easy to fix:

```julia
Flux.train!(model, train, test, η = 1e-4)  # η is the learning rate
```

The training step takes about 5 minutes (to make it faster we can do smarter things like batching; see the sketch below). If you run this code in Juno, you'll see a progress meter, which you can hover over to see the remaining computation time.
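
The exact batched-training API depends on the Flux version, so here is only a rough, plain-Julia sketch of the idea: group examples so that each training step processes many of them at once, amortising per-example overhead.

```julia
# Group the 50,000 training pairs into batches of 100 examples each.
# A batched training loop would iterate over `batches` instead of over
# individual (input, label) pairs.
batches = collect(Iterators.partition(train, 100))
length(batches)  # 500
```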

Towards the end of the training process, Flux will have reported that the accuracy of the model is now about 90%. We can try it on our data again:

```julia
julia> model(data[1][1])
10-element Array{Float32,1}:
 ...
 5.11423f-7
 0.9354
 3.1033f-5
 0.000127077
 ...
```

Notice the class at about 93% probability, suggesting our model is now very confident about this image. We can use `onecold` to compare the true and predicted classes:

```julia
julia> onecold(data[1][2], 0:9)
5

julia> onecold(model(data[1][1]), 0:9)
5
```
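
Going beyond a single example, a short sketch (using only the `model`, `test`, and `onecold` already defined) estimates accuracy over the whole held-out test set:

```julia
# Fraction of test examples whose predicted class matches the true class.
accuracy = count(onecold(model(x), 0:9) == onecold(y, 0:9) for (x, y) in test) /
           length(test)
```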

Success!