sort-of working mnist example
This commit is contained in:
parent
a2aade718d
commit
8335ab8134
|
@ -1,11 +1,8 @@
|
|||
using Flux, MNIST
|
||||
|
||||
@time begin
|
||||
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
||||
train = data[1:50_000]
|
||||
test = data[50_001:60_000]
|
||||
nothing
|
||||
end
|
||||
data = [(trainfeatures(i), Vector{Float64}(onehot(trainlabel(i), 0:9))) for i = 1:60_000]
|
||||
train = data[1:50_000]
|
||||
test = data[50_001:60_000]
|
||||
|
||||
m = Chain(
|
||||
Input(784),
|
||||
|
@ -13,6 +10,7 @@ m = Chain(
|
|||
Dense( 64), relu,
|
||||
Dense( 10), softmax)
|
||||
|
||||
model = mxnet(m, 784)
|
||||
# Convert to TensorFlow
|
||||
model = tf(m)
|
||||
|
||||
@time Flux.train!(model, train, test, epoch = 1, η=0.001)
|
||||
@time Flux.train!(model, train, test, η = 1e-3)
|
||||
|
|
|
@ -33,6 +33,9 @@ end
|
|||
graph(::typeof(*), args...) = *(reverse(args)...)
|
||||
graph(::typeof(+), args...) = +(args...)
|
||||
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||
graph(::typeof(relu), x) = nn.relu(x)
|
||||
|
||||
graph(::Input, x) = x
|
||||
|
||||
type Model
|
||||
session::Session
|
||||
|
@ -73,12 +76,14 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||
Y = placeholder(Float64)
|
||||
Loss = loss(m.graph, Y)
|
||||
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
||||
run(m.session, initialize_all_variables())
|
||||
for e in 1:epoch
|
||||
info("Epoch $e\n")
|
||||
@progress for (x, y) in train
|
||||
y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op), Dict(m.inputs[1]=>x', Y=>y'))
|
||||
i % 1000 == 0 && @show accuracy(m, test)
|
||||
if i % 5000 == 0
|
||||
@show y
|
||||
@show accuracy(m, test)
|
||||
end
|
||||
i += 1
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue