update mnist example
This commit is contained in:
parent
f932f4bd9f
commit
f3a9934858
@ -1,6 +1,6 @@
|
|||||||
using Flux, MNIST
|
using Flux, MNIST
|
||||||
|
|
||||||
data = [(Vector{Float32}(trainfeatures(i)), onehot(Float32, trainlabel(i), 0:9)) for i = 1:60_000]
|
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
||||||
train = data[1:50_000]
|
train = data[1:50_000]
|
||||||
test = data[50_001:60_000]
|
test = data[50_001:60_000]
|
||||||
|
|
||||||
@ -16,7 +16,7 @@ model = tf(m)
|
|||||||
# An example prediction pre-training
|
# An example prediction pre-training
|
||||||
model(data[1][1])
|
model(data[1][1])
|
||||||
|
|
||||||
@time Flux.train!(model, train, test, η = 1e-3)
|
@time Flux.train!(model, train, test, η = 1e-4)
|
||||||
|
|
||||||
# An example prediction post-training
|
# An example prediction post-training
|
||||||
model(data[1][1])
|
model(data[1][1])
|
||||||
|
@ -80,7 +80,8 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||||||
info("Epoch $e\n")
|
info("Epoch $e\n")
|
||||||
@progress for (x, y) in train
|
@progress for (x, y) in train
|
||||||
y, cur_loss, _ = run(m.session, vcat(m.output, Loss, minimize_op),
|
y, cur_loss, _ = run(m.session, vcat(m.output, Loss, minimize_op),
|
||||||
Dict(m.inputs[1]=>batchone(x), Y=>batchone(y)))
|
Dict(m.inputs[1] => batchone(convertel(Float32, x)),
|
||||||
|
Y => batchone(convertel(Float32, y))))
|
||||||
if i % 5000 == 0
|
if i % 5000 == 0
|
||||||
@show y
|
@show y
|
||||||
@show accuracy(m, test)
|
@show accuracy(m, test)
|
||||||
|
@ -2,7 +2,7 @@ export AArray
|
|||||||
|
|
||||||
const AArray = AbstractArray
|
const AArray = AbstractArray
|
||||||
|
|
||||||
initn(dims...) = randn(dims...)/10
|
initn(dims...) = randn(dims...)/100
|
||||||
|
|
||||||
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||||
i = 0
|
i = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user