Flux.jl/examples/integration-tf.jl

55 lines
1.5 KiB
Julia
Raw Normal View History

2016-10-25 15:43:59 +00:00
using Flux, Juno
2016-09-06 17:03:39 +00:00
# LeNet-style classifier. The network is declared with Flux and later turned into
# a TensorFlow graph by `Tensor(lenet, x)` in the training section.

# Feature stage 1: 5×5 convolution producing 20 channels, tanh activation,
# then 2×2 max-pooling with stride 2.
conv1 = Chain(
    Conv2D((5, 5), out = 20),
    tanh,
    MaxPool((2, 2), stride = (2, 2)))

# Feature stage 2: 5×5 convolution from 20 to 50 channels, tanh activation,
# then 2×2 max-pooling with stride 2.
conv2 = Chain(
    Conv2D((5, 5), in = 20, out = 50),
    tanh,
    MaxPool((2, 2), stride = (2, 2)))

# Full model: declared input shape 28×28×1 (presumably single-channel images —
# confirm against the data loader), the two conv stages, flatten, a 500-unit
# tanh hidden layer, and a 10-way softmax output.
lenet = @Chain(
    Input(28, 28, 1),
    conv1,
    conv2,
    flatten,
    Affine(500),
    tanh,
    Affine(10),
    softmax)
2016-09-06 17:03:39 +00:00
#--------------------------------------------------------------------------------
2016-10-04 21:23:53 +00:00
# Now we can continue exactly as in plain TensorFlow, following
# https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl
2016-10-10 22:48:25 +00:00
# (taking only the training and cost logic, not the graph building steps)
2016-09-06 17:03:39 +00:00
2016-10-04 21:50:42 +00:00
using TensorFlow, Distributions
# MNIST data via TensorFlow.jl's bundled example loader; `DataLoader`,
# `next_batch` and `load_test_set` below are presumably provided by that file.
include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl"))
loader = DataLoader()

session = Session(Graph())

# Inputs and targets need two *distinct* placeholders: `x` for the images and
# `y_` for the labels. The network output gets its own name, `y`. Reusing `y`
# for both discards the label placeholder: the loss degenerates into the
# prediction's self-entropy, the accuracy compares the prediction with itself
# (always 1), and the labels get fed into the network output instead.
# `y_` is plain ASCII so it survives copy/paste and re-encoding intact.
x = placeholder(Float32)
y_ = placeholder(Float32)
y = Tensor(lenet, x)

# Mean cross-entropy between the labels `y_` and the predicted probabilities
# `y`, summed over dimension 2 (assumed to be the class axis, consistent with
# `indmax(..., 2)` below) and averaged over the batch.
cross_entropy = reduce_mean(-reduce_sum(y_ .* log(y), reduction_indices = [2]))

train_step = train.minimize(train.AdamOptimizer(1e-4), cross_entropy)

# Fraction of examples whose predicted class matches the labelled class.
accuracy = reduce_mean(cast(indmax(y, 2) .== indmax(y_, 2), Float32))

run(session, global_variables_initializer())

# 1000 Adam steps on batches of 50; report training accuracy every 100 steps
# (at i = 1, 101, ..., 901), measured on the current batch before it is trained on.
@progress for i in 1:1000
    batch = next_batch(loader, 50)
    feed = Dict(x => batch[1], y_ => batch[2])
    if i % 100 == 1
        train_accuracy = run(session, accuracy, feed)
        info("step $i, training accuracy $train_accuracy")
    end
    run(session, train_step, feed)
end

testx, testy = load_test_set()
test_accuracy = run(session, accuracy, Dict(x => testx, y_ => testy))
info("test accuracy $test_accuracy")