non-working example
This commit is contained in:
parent
45d30312b6
commit
2a375c4eb2
@ -36,15 +36,36 @@ lenet = Chain(
|
|||||||
# Now we can continue exactly as in plain TensorFlow, following
|
# Now we can continue exactly as in plain TensorFlow, following
|
||||||
# https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl
|
# https://github.com/malmaud/TensorFlow.jl/blob/master/examples/mnist_full.jl
|
||||||
|
|
||||||
using TensorFlow
|
using TensorFlow, Distributions
|
||||||
|
|
||||||
|
include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl"))
|
||||||
|
loader = DataLoader()
|
||||||
|
|
||||||
sess = Session(Graph())
|
sess = Session(Graph())
|
||||||
|
|
||||||
x = placeholder(Float32)
|
x = placeholder(Float32)
|
||||||
y′ = placeholder(Float32)
|
y′ = placeholder(Float32)
|
||||||
|
y = Tensor(lenet, x)
|
||||||
|
|
||||||
y = Tensor(lenet, x)
|
cross_entropy = reduce_mean(-reduce_sum(y′.*log(y), reduction_indices=[2]))
|
||||||
|
|
||||||
include(Pkg.dir("TensorFlow", "examples", "mnist_loader.jl"))
|
train_step = train.minimize(train.AdamOptimizer(1e-4), cross_entropy)
|
||||||
|
|
||||||
loader = DataLoader()
|
correct_prediction = indmax(y, 2) .== indmax(y′, 2)
|
||||||
|
|
||||||
|
accuracy = reduce_mean(cast(correct_prediction, Float32))
|
||||||
|
|
||||||
|
run(session, initialize_all_variables())
|
||||||
|
|
||||||
|
for i in 1:1000
|
||||||
|
batch = next_batch(loader, 50)
|
||||||
|
if i%100 == 1
|
||||||
|
train_accuracy = run(session, accuracy, Dict(x=>batch[1], y′=>batch[2], keep_prob=>1.0))
|
||||||
|
info("step $i, training accuracy $train_accuracy")
|
||||||
|
end
|
||||||
|
run(session, train_step, Dict(x=>batch[1], y′=>batch[2], keep_prob=>.5))
|
||||||
|
end
|
||||||
|
|
||||||
|
testx, testy = load_test_set()
|
||||||
|
test_accuracy = run(session, accuracy, Dict(x=>testx, y′=>testy, keep_prob=>1.0))
|
||||||
|
info("test accuracy $test_accuracy")
|
||||||
|
Loading…
Reference in New Issue
Block a user