# LeNet-style convolutional network for MNIST, defined with Flux and trained
# through MXNet.jl.
# Flux builds the model; MXNet.jl supplies the `mx` namespace used for
# training further down.
using Flux
using MXNet

# Enable Flux's MXNet integration. This is assumed to be what allows the Flux
# `Chain` defined below to be passed to `mx.FeedForward` — confirm against the
# Flux version in use.
Flux.loadmx()
# First convolutional stage:
#   28x28 input -> 5x5 convolution with 20 output channels -> tanh
#   -> 2x2 max-pooling with stride 2.
conv1 = Chain(
    Input(28, 28),
    Conv2D((5, 5), out = 20), tanh,
    MaxPool((2, 2), stride = (2, 2)))
# Second convolutional stage, stacked on top of `conv1`:
#   5x5 convolution mapping 20 input channels (matching conv1's `out = 20`)
#   to 50 output channels -> tanh -> 2x2 max-pooling with stride 2.
conv2 = Chain(
    conv1,
    Conv2D((5, 5), in = 20, out = 50), tanh,
    MaxPool((2, 2), stride = (2, 2)))
# Complete LeNet network: both convolutional stages, flattened, then a
# 500-unit hidden layer with tanh, and a 10-way output layer with softmax
# (one score per MNIST digit class).
# NOTE(review): `Affine(n)` is assumed to fix only the output width, with the
# input width inferred from the upstream `Input(28, 28)` — confirm against the
# Flux version in use.
lenet = Chain(
    conv2,
    flatten,
    Affine(500), tanh,
    Affine(10), softmax)
#--------------------------------------------------------------------------------

# With the network defined, the rest of the script is ordinary MXNet.jl code,
# following the LeNet MNIST example at
#   https://github.com/dmlc/MXNet.jl/blob/master/examples/mnist/lenet.jl
# Mini-batch size, shared by the data providers and the shape check below so
# the two cannot disagree.
batch_size = 100

# Load MXNet.jl's MNIST helpers, which define `get_mnist_providers`.
# NOTE(review): `Pkg.dir` is the pre-0.7 package-path API and is deprecated on
# newer Julia. It is left unchanged because the Flux/MXNet APIs in this script
# appear to target that older toolchain as well.
include(Pkg.dir("MXNet", "examples", "mnist", "mnist-data.jl"))

# `flat=false` presumably keeps each sample as an image tensor rather than a
# flat vector, as the convolutional layers require — confirm in mnist-data.jl.
train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)

# Wrap the Flux-defined `lenet` in an MXNet feed-forward model.
model = mx.FeedForward(lenet)

# Propagate a (28, 28, 1, batch) input shape through the model architecture.
# The result is discarded; presumably this only serves as an early shape
# sanity check. The batch dimension comes from `batch_size` rather than a
# duplicated literal, so changing the batch size keeps this consistent.
mx.infer_shape(model.arch, data = (28, 28, 1, batch_size))

# Stochastic gradient descent with momentum and a small weight decay.
optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)

# Train for one epoch on `train_provider`, evaluating on `eval_provider`.
mx.fit(model, optimizer, train_provider, n_epoch=1, eval_data=eval_provider)