# Flux.jl/examples/mnist-conv.jl
using Flux

# Flux aims to provide high-level APIs that work well across backends, but in
# some cases you may want to take advantage of features specific to a given
# backend (or, alternatively, Flux may simply not have an implementation of that
# feature yet). In these cases it's easy to "drop down" and use the backend's
# API directly, where appropriate.

# In this example, both things are happening. Firstly, Flux doesn't yet support
# ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
# a simple "shim" type, `Conv2D`. This is provided by the library but could easily
# have been user-defined.

# Secondly, we want to take advantage of TensorFlow.jl's training process and
# optimisers. We can simply call `Tensor(lenet, x)` to get the model's output as
# a TensorFlow tensor, and from there we proceed exactly as we would with a
# regular TensorFlow model; the rest of the process is trivial.

conv1 = Chain(
  Input(28,28),
  Conv2D((5,5), out = 20), tanh,
  MaxPool((2,2), stride = (2,2)))

conv2 = Chain(
  conv1,
  Conv2D((5,5), out = 50), tanh,
  MaxPool((2,2), stride = (2,2)))

lenet = Chain(
  conv2,
  flatten,
  Dense(500), tanh,
  Dense(10), softmax)
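
# Shape check: a minimal sketch, not part of Flux's API. Tracing a 28×28 input
# through `lenet` by hand shows how many features `flatten` passes to `Dense(500)`.
# ASSUMPTIONS: `Conv2D` is an unpadded ("valid") convolution with stride 1, and
# `MaxPool((2,2), stride = (2,2))` pools non-overlapping 2×2 windows. The helper
# names `valid_conv_side`, `pool_side` and `lenet_flat_size` are hypothetical,
# introduced only for this illustration.
valid_conv_side(n, k) = n - k + 1          # side length after a k×k valid conv, stride 1
pool_side(n, k, s)    = div(n - k, s) + 1  # side length after a k×k pool with stride s

function lenet_flat_size(side)
  side = pool_side(valid_conv_side(side, 5), 2, 2)  # conv1 block (28 -> 24 -> 12 for MNIST)
  side = pool_side(valid_conv_side(side, 5), 2, 2)  # conv2 block (12 -> 8 -> 4 for MNIST)
  return side * side * 50                           # the second Conv2D has 50 output channels
end

@show lenet_flat_size(28)  # prints "lenet_flat_size(28) = 800" under the assumptions above
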
#--------------------------------------------------------------------------------

# Now we can continue exactly as in plain TensorFlow, following
# https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl

using TensorFlow

sess = Session(Graph())

x  = placeholder(Float32)  # input images
y′ = placeholder(Float32)  # target labels; kept separate so `y` can hold the prediction
y  = Tensor(lenet, x)      # `lenet`'s output, built as a TensorFlow tensor

# MNIST loader shipped with TensorFlow.jl's own examples
include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl"))
loader = DataLoader()