initial conv example changes for TF

Mike J Innes 2016-10-04 22:23:53 +01:00
parent 9e9c57d49b
commit c646ba4483

@@ -1,4 +1,4 @@
-using Flux, MXNet
+using Flux
 # Flux aims to provide high-level APIs that work well across backends, but in
 # some cases you may want to take advantage of features specific to a given
@@ -8,21 +8,21 @@ using Flux, MXNet
 # In this example, both things are happening; firstly, Flux doesn't yet support
 # ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
-# a simple "shim" type, `Conv`. This is provided by the library but could easily
+# a simple "shim" type, `Conv2D`. This is provided by the library but could easily
 # have been user-defined.
-# Secondly, we want to take advantage of MXNet.jl's training process and
+# Secondly, we want to take advantage of TensorFlow.jl's training process and
 # optimisers. We can simply call `mx.FeedForward` exactly as we would on a
-# regular MXNet model, and the rest of the process is trivial.
+# regular TensorFlow model, and the rest of the process is trivial.
 conv1 = Chain(
   Input(28,28),
-  Conv((5,5),20), tanh,
+  Conv2D((5,5), out = 20), tanh,
   MaxPool((2,2), stride = (2,2)))
 conv2 = Chain(
   conv1,
-  Conv((5,5),50), tanh,
+  Conv2D((5,5), out = 50), tanh,
   MaxPool((2,2), stride = (2,2)))
 lenet = Chain(
@@ -33,15 +33,18 @@ lenet = Chain(
 #--------------------------------------------------------------------------------
-# Now we can continue exactly as in plain MXNet, following
-# https://github.com/dmlc/MXNet.jl/blob/master/examples/mnist/lenet.jl
-batch_size = 100
-include(Pkg.dir("MXNet", "examples", "mnist", "mnist-data.jl"))
-train_provider, eval_provider = get_mnist_providers(batch_size; flat=false)
-model = mx.FeedForward(lenet, context = mx.gpu())
-optimizer = mx.SGD(lr=0.05, momentum=0.9, weight_decay=0.00001)
-mx.fit(model, optimizer, train_provider, n_epoch=1, eval_data=eval_provider)
+# Now we can continue exactly as in plain TensorFlow, following
+# https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl
+using TensorFlow
+sess = Session(Graph())
+x = placeholder(Float64)
+y = placeholder(Float64)
+y = Tensor(lenet, x)
+include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl"))
+loader = DataLoader()
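
The capture of the diff ends here; the hunk counts (-15 +18) imply a few more added lines that are not shown. For orientation, below is a minimal sketch of how a training loop over `loader` is typically completed against the TensorFlow.jl API of that era, modelled on the mnist_full.jl example the comments link to. Everything in it is an illustrative assumption rather than part of the commit: the label placeholder `y_`, the cross-entropy loss, the optimiser choice, and the batch size are invented for the sketch. (Note that the committed code rebinds `y` from a placeholder to the model output; the sketch keeps `y` as the model output and uses `y_` for the labels.)

# Hypothetical continuation -- not part of commit c646ba4483.
y_ = placeholder(Float64)                        # true labels (assumed name)
loss = -reduce_sum(y_ .* log(nn.softmax(y)))     # cross-entropy over the model output
train_step = train.minimize(train.GradientDescentOptimizer(0.01), loss)
run(sess, initialize_all_variables())            # initialiser name from TF.jl of the period
for i in 1:1000
    batch = next_batch(loader, 100)              # DataLoader/next_batch come from mnist_loader.jl
    run(sess, train_step, Dict(x => batch[1], y_ => batch[2]))
end

The sketch only restates the division of labour the comments describe: Flux defines `lenet`, and everything after `Tensor(lenet, x)` is plain TensorFlow.jl.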