# LeNet on MNIST: model built with Flux, trained through the TensorFlow.jl backend.
using Flux, Juno
# First convolutional stage of LeNet: a 5×5 convolution producing 20 feature
# maps, a tanh nonlinearity, then 2×2 max-pooling with stride 2.
conv1 = Chain(
  Conv2D((5,5), out = 20), tanh,
  MaxPool((2,2), stride = (2,2)))
# Second convolutional stage: consumes the 20 channels produced by `conv1`
# and emits 50 feature maps, again followed by tanh and 2×2 max-pooling.
conv2 = Chain(
  Conv2D((5,5), in = 20, out = 50), tanh,
  MaxPool((2,2), stride = (2,2)))
# Full LeNet model: a 28×28×1 MNIST image flows through both conv stages,
# is flattened, then passes through two fully-connected (Affine) layers
# ending in a softmax over the 10 digit classes.
lenet = @Chain(
  Input(28,28,1),
  conv1, conv2,
  flatten,
  Affine(500), tanh,
  Affine(10), softmax)
#--------------------------------------------------------------------------------

# Now we can continue exactly as in plain TensorFlow, following
# https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl
# (taking only the training and cost logic, not the graph building steps)
using TensorFlow, Distributions

# Reuse the MNIST data loader that ships with TensorFlow.jl's examples;
# it provides `DataLoader`, `next_batch`, and `load_test_set`.
include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl"))

loader = DataLoader()

# Fresh graph and session for this run.
session = Session(Graph())
# Placeholders for a minibatch of images (x) and their one-hot labels (y′).
x = placeholder(Float32)
y′ = placeholder(Float32)

# Lower the Flux model into a TensorFlow graph node applied to `x`.
y = Tensor(lenet, x)

# NOTE(review): taking log of the softmax output is numerically unstable when
# probabilities approach zero; kept as-is to mirror the upstream example,
# which computes the loss the same way.
cross_entropy = reduce_mean(-reduce_sum(y′.*log(y), reduction_indices=[2]))

# Adam with a fixed 1e-4 learning rate, minimizing the cross-entropy.
train_step = train.minimize(train.AdamOptimizer(1e-4), cross_entropy)

# Fraction of rows where the predicted argmax matches the label's argmax.
accuracy = reduce_mean(cast(indmax(y, 2) .== indmax(y′, 2), Float32))

run(session, global_variables_initializer())
# Train for 1000 steps on 50-image minibatches, reporting training-batch
# accuracy every 100 steps (on step 1, 101, 201, …).
@progress for i in 1:1000
  batch = next_batch(loader, 50)
  if i%100 == 1
    train_accuracy = run(session, accuracy, Dict(x=>batch[1], y′=>batch[2]))
    info("step $i, training accuracy $train_accuracy")
  end
  run(session, train_step, Dict(x=>batch[1], y′=>batch[2]))
end
# Final evaluation: accuracy over the held-out MNIST test set.
testx, testy = load_test_set()

test_accuracy = run(session, accuracy, Dict(x=>testx, y′=>testy))

info("test accuracy $test_accuracy")