2016-08-25 21:49:21 +00:00
|
|
|
using Flux, MNIST
|
2016-04-01 21:11:42 +00:00
|
|
|
|
2016-10-04 21:36:56 +00:00
|
|
|
# Build the dataset as a vector of (features, target) tuples for samples
# 1 through 60,000. Each target is the label returned by `trainlabel(i)`,
# one-hot encoded over the classes 0:9 and converted to a Vector{Float32}.
data = map(1:60_000) do i
    (trainfeatures(i), Vector{Float32}(onehot(trainlabel(i), 0:9)))
end
|
2016-09-29 20:28:53 +00:00
|
|
|
# Split the samples: the first 50,000 go to `train`, the remaining 10,000
# (indices 50,001 through 60,000) are held out as `test`.
train, test = data[1:50_000], data[50_001:60_000]
|
2016-04-01 21:11:42 +00:00
|
|
|
|
2016-08-25 16:26:52 +00:00
|
|
|
# Model definition: a chain starting from a 784-wide input, followed by
# Dense layers of width 128 and 64 (each followed by relu) and a final
# Dense layer of width 10 followed by softmax.
# NOTE(review): 784 presumably corresponds to a flattened 28x28 image and 10
# to the digit classes 0:9 used for the one-hot targets; confirm.
m = Chain(
    Input(784),
    Dense(128),
    relu,
    Dense(64),
    relu,
    Dense(10),
    softmax)
|
2016-04-01 21:11:42 +00:00
|
|
|
|
2016-09-29 20:28:53 +00:00
|
|
|
# Convert the chain `m` to TensorFlow by passing it through `tf`; the
# result `model` is what gets called and trained below.
model = m |> tf
|
2016-08-25 16:26:52 +00:00
|
|
|
|
2016-10-04 20:11:03 +00:00
|
|
|
# Example prediction before training: apply the model to the features
# (first tuple element) of the first sample in `data`.
model(first(first(data)))
|
|
|
|
|
2016-09-29 20:28:53 +00:00
|
|
|
# Train `model` on `train`, passing `test` alongside it, and report the
# elapsed time and allocations via `@time`.
# NOTE(review): η = 1e-3 is presumably the learning rate and `test` the
# evaluation set; confirm against `Flux.train!`.
@time Flux.train!(model, train, test; η = 1e-3)
|
2016-10-04 20:11:03 +00:00
|
|
|
|
|
|
|
# Example prediction after training: apply the model to the features
# (first tuple element) of the first sample in `data`, the same input used
# for the pre-training prediction.
model(first(first(data)))
|