working mnist-conv example

This commit is contained in:
Mike J Innes 2016-10-10 23:48:25 +01:00
parent a56af5d16e
commit bfb8d961e2

View File

@@ -2,9 +2,9 @@ using Flux
# Flux aims to provide high-level APIs that work well across backends, but in # Flux aims to provide high-level APIs that work well across backends, but in
# some cases you may want to take advantage of features specific to a given # some cases you may want to take advantage of features specific to a given
# backend (or alternatively, Flux may simply not have an implementation of that # backend (or Flux may simply not have an implementation of that feature yet).
# feature yet). In these cases it's easy to "drop down" and use the backend's # In these cases it's easy to "drop down" and use the backend's API directly,
# API directly, where appropriate. # where appropriate.
# In this example, both things are happening; firstly, Flux doesn't yet support # In this example, both things are happening; firstly, Flux doesn't yet support
# ConvNets in the pure-Julia backend, but this is invisible thanks to the use of # ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
@@ -12,22 +12,22 @@ using Flux
# have been user-defined. # have been user-defined.
# Secondly, we want to take advantage of TensorFlow.jl's training process and # Secondly, we want to take advantage of TensorFlow.jl's training process and
# optimisers. We can simply call `mx.FeedForward` exactly as we would on a # optimisers. We can simply call `Tensor` exactly as we would on a regular
# regular TensorFlow model, and the rest of the process is trivial. # TensorFlow model, and the rest of the process trivially follows
# TensorFlow.jl's usual API.
conv1 = Chain( conv1 = Chain(
Input(28,28), Reshape(28,28,1),
Conv2D((5,5), out = 20), tanh, Conv2D((5,5), out = 20), tanh,
MaxPool((2,2), stride = (2,2))) MaxPool((2,2), stride = (2,2)))
conv2 = Chain( conv2 = Chain(
conv1, Input(12,12,20),
Conv2D((5,5), out = 50), tanh, Conv2D((5,5), in = 20, out = 50), tanh,
MaxPool((2,2), stride = (2,2))) MaxPool((2,2), stride = (2,2)))
lenet = Chain( lenet = Chain(
conv2, conv1, conv2, flatten,
flatten,
Dense(500), tanh, Dense(500), tanh,
Dense(10), softmax) Dense(10), softmax)
@@ -35,13 +35,14 @@ lenet = Chain(
# Now we can continue exactly as in plain TensorFlow, following # Now we can continue exactly as in plain TensorFlow, following
# https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl # https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl
# (taking only the training and cost logic, not the graph building steps)
using TensorFlow, Distributions using TensorFlow, Distributions
include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl")) include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl"))
loader = DataLoader() loader = DataLoader()
sess = Session(Graph()) session = Session(Graph())
x = placeholder(Float32) x = placeholder(Float32)
y = placeholder(Float32) y = placeholder(Float32)
@@ -51,21 +52,19 @@ cross_entropy = reduce_mean(-reduce_sum(y.*log(y), reduction_indices=[2]))
train_step = train.minimize(train.AdamOptimizer(1e-4), cross_entropy) train_step = train.minimize(train.AdamOptimizer(1e-4), cross_entropy)
correct_prediction = indmax(y, 2) .== indmax(y, 2) accuracy = reduce_mean(cast(indmax(y, 2) .== indmax(y, 2), Float32))
accuracy = reduce_mean(cast(correct_prediction, Float32))
run(session, initialize_all_variables()) run(session, initialize_all_variables())
for i in 1:1000 @progress for i in 1:1000
batch = next_batch(loader, 50) batch = next_batch(loader, 50)
if i%100 == 1 if i%100 == 1
train_accuracy = run(session, accuracy, Dict(x=>batch[1], y=>batch[2], keep_prob=>1.0)) train_accuracy = run(session, accuracy, Dict(x=>batch[1], y=>batch[2]))
info("step $i, training accuracy $train_accuracy") info("step $i, training accuracy $train_accuracy")
end end
run(session, train_step, Dict(x=>batch[1], y=>batch[2], keep_prob=>.5)) run(session, train_step, Dict(x=>batch[1], y=>batch[2]))
end end
testx, testy = load_test_set() testx, testy = load_test_set()
test_accuracy = run(session, accuracy, Dict(x=>testx, y=>testy, keep_prob=>1.0)) test_accuracy = run(session, accuracy, Dict(x=>testx, y=>testy))
info("test accuracy $test_accuracy") info("test accuracy $test_accuracy")