working mnist-conv example
commit bfb8d961e2
parent a56af5d16e
@@ -2,9 +2,9 @@ using Flux
 
 # Flux aims to provide high-level APIs that work well across backends, but in
 # some cases you may want to take advantage of features specific to a given
-# backend (or alternatively, Flux may simply not have an implementation of that
-# feature yet). In these cases it's easy to "drop down" and use the backend's
-# API directly, where appropriate.
+# backend (or Flux may simply not have an implementation of that feature yet).
+# In these cases it's easy to "drop down" and use the backend's API directly,
+# where appropriate.
 
 # In this example, both things are happening; firstly, Flux doesn't yet support
 # ConvNets in the pure-Julia backend, but this is invisible thanks to the use of
@@ -12,22 +12,22 @@ using Flux
 # have been user-defined.
 
 # Secondly, we want to take advantage of TensorFlow.jl's training process and
-# optimisers. We can simply call `mx.FeedForward` exactly as we would on a
-# regular TensorFlow model, and the rest of the process is trivial.
+# optimisers. We can simply call `Tensor` exactly as we would on a regular
+# TensorFlow model, and the rest of the process trivially follows
+# TensorFlow.jl's usual API.
 
 conv1 = Chain(
-  Input(28,28),
+  Reshape(28,28,1),
   Conv2D((5,5), out = 20), tanh,
   MaxPool((2,2), stride = (2,2)))
 
 conv2 = Chain(
-  conv1,
-  Conv2D((5,5), out = 50), tanh,
+  Input(12,12,20),
+  Conv2D((5,5), in = 20, out = 50), tanh,
   MaxPool((2,2), stride = (2,2)))
 
 lenet = Chain(
-  conv2,
-  flatten,
+  conv1, conv2, flatten,
   Dense(500), tanh,
   Dense(10), softmax)
 
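Why conv1 now begins with Reshape(28,28,1) and conv2 with Input(12,12,20): a valid 5x5 convolution takes 28x28 to 24x24, and the 2x2 max-pool with stride 2 halves that to 12x12 with 20 feature maps; a second conv/pool round leaves 4x4x50 = 800 flattened features ahead of Dense(500). A minimal sketch of that shape arithmetic (the helper functions are illustrative, not Flux API):

convsize(n, k) = n - k + 1               # valid convolution: 28 -> 24
poolsize(n, k, s) = div(n - k, s) + 1    # 2x2 pool, stride 2: 24 -> 12

h1 = poolsize(convsize(28, 5), 2, 2)     # 12, matching Input(12,12,20) in conv2
h2 = poolsize(convsize(h1, 5), 2, 2)     # 4, so flatten yields 4*4*50 = 800 features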
@@ -35,13 +35,14 @@ lenet = Chain(
 
 # Now we can continue exactly as in plain TensorFlow, following
 # https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl
+# (taking only the training and cost logic, not the graph building steps)
 
 using TensorFlow, Distributions
 
 include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl"))
 loader = DataLoader()
 
-sess = Session(Graph())
+session = Session(Graph())
 
 x = placeholder(Float32)
 y′ = placeholder(Float32)
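The unchanged lines elided before the next hunk are where the Flux model joins the TensorFlow graph. Going by the `Tensor` comment above, that bridging step plausibly resembles the sketch below; this is an assumption based on that comment, not a line shown in this diff:

# Assumed bridge (not shown in this diff): wrap the Flux model as a
# TensorFlow tensor applied to the x placeholder, yielding the y used
# by cross_entropy and accuracy below.
y = Tensor(lenet, x)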
@@ -51,21 +52,19 @@ cross_entropy = reduce_mean(-reduce_sum(y′.*log(y), reduction_indices=[2]))
 
 train_step = train.minimize(train.AdamOptimizer(1e-4), cross_entropy)
 
-correct_prediction = indmax(y, 2) .== indmax(y′, 2)
-
-accuracy = reduce_mean(cast(correct_prediction, Float32))
+accuracy = reduce_mean(cast(indmax(y, 2) .== indmax(y′, 2), Float32))
 
 run(session, initialize_all_variables())
 
-for i in 1:1000
+@progress for i in 1:1000
   batch = next_batch(loader, 50)
   if i%100 == 1
-    train_accuracy = run(session, accuracy, Dict(x=>batch[1], y′=>batch[2], keep_prob=>1.0))
+    train_accuracy = run(session, accuracy, Dict(x=>batch[1], y′=>batch[2]))
     info("step $i, training accuracy $train_accuracy")
   end
-  run(session, train_step, Dict(x=>batch[1], y′=>batch[2], keep_prob=>.5))
+  run(session, train_step, Dict(x=>batch[1], y′=>batch[2]))
 end
 
 testx, testy = load_test_set()
-test_accuracy = run(session, accuracy, Dict(x=>testx, y′=>testy, keep_prob=>1.0))
+test_accuracy = run(session, accuracy, Dict(x=>testx, y′=>testy))
 info("test accuracy $test_accuracy")
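For intuition on the merged accuracy node and the cross_entropy expression in the hunk header: TensorFlow.jl is 1-indexed, so reduction_indices=[2] and indmax(..., 2) reduce along the class dimension, with rows as samples. (The keep_prob feed entries could be dropped because this model defines no dropout placeholder.) A tiny pure-Julia analogue on toy data, illustrative only:

# Toy batch: rows are samples, columns are the two classes.
y  = [0.9 0.1; 0.2 0.8]    # softmax outputs
y′ = [1.0 0.0; 0.0 1.0]    # one-hot labels
cross_entropy = mean(-sum(y′ .* log(y), 2))                          # ~0.164
accuracy = mean([indmax(y[i, :]) == indmax(y′[i, :]) for i in 1:2])  # 1.0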