From 7cfb107da709091a62a95b75f265531203b1eeae Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Thu, 25 Aug 2016 17:26:52 +0100 Subject: [PATCH] run with what we have --- examples/MNIST.jl | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/examples/MNIST.jl b/examples/MNIST.jl index c7dcfa4c..d6f705b7 100644 --- a/examples/MNIST.jl +++ b/examples/MNIST.jl @@ -1,13 +1,20 @@ -using Flux, MNIST +using Flux, MNIST, Flow, MacroTools +import Flux.MX: mxnet +import Flux: back!, update!, graph -const data = collect(zip([trainfeatures(i) for i = 1:60_000], - [onehot(trainlabel(i), 1:10) for i = 1:60_000])) -const train = data[1:50_000] -const test = data[50_001:60_000] +@time begin + const data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000] + const train = data[1:50_000] + const test = data[50_001:60_000] + nothing +end -const m = Sequence( +m = Chain( Input(784), - Dense(30), Sigmoid(), - Dense(10), Sigmoid()) + Dense(784, 128), relu, + Dense(128, 64), relu, + Dense(64, 10), softmax) -@time Flux.train!(m, train, test, epoch = 30) +model = mxnet(m, 784) + +@time Flux.train!(model, train, test, epoch = 1, η=0.001)