update mnist example
This commit is contained in:
parent
f932f4bd9f
commit
f3a9934858
@ -1,6 +1,6 @@
|
||||
using Flux, MNIST
|
||||
|
||||
data = [(Vector{Float32}(trainfeatures(i)), onehot(Float32, trainlabel(i), 0:9)) for i = 1:60_000]
|
||||
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
||||
train = data[1:50_000]
|
||||
test = data[50_001:60_000]
|
||||
|
||||
@ -16,7 +16,7 @@ model = tf(m)
|
||||
# An example prediction pre-training
|
||||
model(data[1][1])
|
||||
|
||||
@time Flux.train!(model, train, test, η = 1e-3)
|
||||
@time Flux.train!(model, train, test, η = 1e-4)
|
||||
|
||||
# An example prediction post-training
|
||||
model(data[1][1])
|
||||
|
@ -80,7 +80,8 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
||||
info("Epoch $e\n")
|
||||
@progress for (x, y) in train
|
||||
y, cur_loss, _ = run(m.session, vcat(m.output, Loss, minimize_op),
|
||||
Dict(m.inputs[1]=>batchone(x), Y=>batchone(y)))
|
||||
Dict(m.inputs[1] => batchone(convertel(Float32, x)),
|
||||
Y => batchone(convertel(Float32, y))))
|
||||
if i % 5000 == 0
|
||||
@show y
|
||||
@show accuracy(m, test)
|
||||
|
@ -2,7 +2,7 @@ export AArray
|
||||
|
||||
const AArray = AbstractArray
|
||||
|
||||
initn(dims...) = randn(dims...)/10
|
||||
initn(dims...) = randn(dims...)/100
|
||||
|
||||
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||
i = 0
|
||||
|
Loading…
Reference in New Issue
Block a user