Flux.jl/examples/MNIST.jl

25 lines
510 B
Julia
Raw Normal View History

2016-08-25 21:49:21 +00:00
# Train a small multilayer perceptron on MNIST, converting the Flux model to MXNet.
using Flux, MNIST
using Flux: accuracy

# Build (features, one-hot label) pairs for all 60_000 MNIST training images;
# labels are one-hot encoded over the digit classes 0:9.
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]

# First 50_000 examples train the model; the remainder is held out for the
# accuracy callback. `end` avoids repeating the dataset size.
train = data[1:50_000]
test = data[50_001:end]

# 784 inputs (presumably the flattened 28x28 image from `trainfeatures` -- confirm),
# two ReLU hidden layers, and a softmax over the 10 digit classes.
m = @Chain(
    Input(784),
    Affine(128), relu,
    Affine( 64), relu,
    Affine( 10), softmax)

# Convert to MXNet; `model` is the object trained below.
model = mxnet(m)

# An example prediction pre-training. Wrapped in `@show` because a bare
# expression prints nothing when this file is run as a script or `include`d;
# `@show` still returns the value.
@show model(tobatch(data[1][1]))

# `η` is presumably the learning rate -- confirm against `Flux.train!`.
# NOTE(review): the callback scores `m` (the Flux model), not the trained
# `model`; this presumably relies on the backend keeping their parameters in
# sync, and `accuracy` may need unbatched inputs that `model` does not accept.
Flux.train!(model, train, η = 1e-3,
            cb = [()->@show accuracy(m, test)])

# An example prediction post-training, for comparison with the one above.
@show model(tobatch(data[1][1]))