resuscitate mxnet integration example
parent a794f068a5
commit 5f355d96ad
51
examples/integration-mx.jl
Normal file
@@ -0,0 +1,51 @@
using Flux, MXNet

Flux.loadmx()

# Flux aims to provide high-level APIs that work well across backends, but in
# some cases you may want to take advantage of features specific to a given
# backend (or alternatively, Flux may simply not have an implementation of that
# feature yet). In these cases it's easy to "drop down" and use the backend's
# API directly, where appropriate.

# In this example, both things are happening: firstly, Flux doesn't yet support
# ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
# a simple "shim" type, `Conv2D`. This is provided by the library but could easily
# have been user-defined.
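
# As an illustration of the "could have been user-defined" point: a minimal
# sketch of a custom shim. Everything here is hypothetical except
# `mx.Pooling`, which is plain MXNet.jl symbol API; in particular, the
# `Flux.graph` lowering hook is an assumed name standing in for whatever entry
# point the MXNet backend uses to translate a layer into symbols.

# A shim carries just enough structure to describe the layer...
struct MeanPool
  size::Tuple{Int,Int}
end

# ...and a lowering rule maps it onto the backend's native operator.
Flux.graph(m::MeanPool, x) =
  mx.Pooling(data = x, kernel = m.size, pool_type = :avg)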

# Secondly, we want to take advantage of MXNet.jl's training process and
# optimisers. We can simply call `mx.FeedForward` exactly as we would on a
# regular MXNet model, and the rest of the process is trivial.

conv1 = Chain(
  Input(28,28),
  Conv2D((5,5), out = 20), tanh,
  MaxPool((2,2), stride = (2,2)))

conv2 = Chain(
  conv1,
  Conv2D((5,5), in = 20, out = 50), tanh,
  MaxPool((2,2), stride = (2,2)))

lenet = Chain(
  conv2,
  flatten,
  Affine(500), tanh,
  Affine(10), softmax)

#--------------------------------------------------------------------------------

# Now we can continue exactly as in plain MXNet, following
# https://github.com/dmlc/MXNet.jl/blob/master/examples/mnist/lenet.jl

batch_size = 100
include(Pkg.dir("MXNet", "examples", "mnist", "mnist-data.jl"))
train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)

model = mx.FeedForward(lenet)

# Sanity-check that shapes propagate through the graph; MXNet.jl shapes are
# column-major, so this is width × height × channels × batch size.
mx.infer_shape(model.arch, data = (28, 28, 1, 100))

optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)

mx.fit(model, optimizer, train_provider, n_epoch=1, eval_data=eval_provider)
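
# Optional follow-up (a sketch beyond the original example): predictions also
# come straight from MXNet.jl. `mx.predict` is standard MXNet.jl API; the
# rest is illustrative.
probs = mx.predict(model, eval_provider)
println(size(probs))  # 10 classes × number of evaluation samples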