sort-of working mnist example
This commit is contained in:
parent
a2aade718d
commit
8335ab8134
@ -1,11 +1,8 @@
|
|||||||
using Flux, MNIST
|
using Flux, MNIST
|
||||||
|
|
||||||
@time begin
|
data = [(trainfeatures(i), Vector{Float64}(onehot(trainlabel(i), 0:9))) for i = 1:60_000]
|
||||||
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
|
||||||
train = data[1:50_000]
|
train = data[1:50_000]
|
||||||
test = data[50_001:60_000]
|
test = data[50_001:60_000]
|
||||||
nothing
|
|
||||||
end
|
|
||||||
|
|
||||||
m = Chain(
|
m = Chain(
|
||||||
Input(784),
|
Input(784),
|
||||||
@ -13,6 +10,7 @@ m = Chain(
|
|||||||
Dense( 64), relu,
|
Dense( 64), relu,
|
||||||
Dense( 10), softmax)
|
Dense( 10), softmax)
|
||||||
|
|
||||||
model = mxnet(m, 784)
|
# Convert to TensorFlow
|
||||||
|
model = tf(m)
|
||||||
|
|
||||||
@time Flux.train!(model, train, test, epoch = 1, η=0.001)
|
@time Flux.train!(model, train, test, η = 1e-3)
|
||||||
|
@ -33,6 +33,9 @@ end
|
|||||||
graph(::typeof(*), args...) = *(reverse(args)...)
|
graph(::typeof(*), args...) = *(reverse(args)...)
|
||||||
graph(::typeof(+), args...) = +(args...)
|
graph(::typeof(+), args...) = +(args...)
|
||||||
graph(::typeof(softmax), x) = nn.softmax(x)
|
graph(::typeof(softmax), x) = nn.softmax(x)
|
||||||
|
graph(::typeof(relu), x) = nn.relu(x)
|
||||||
|
|
||||||
|
graph(::Input, x) = x
|
||||||
|
|
||||||
type Model
|
type Model
|
||||||
session::Session
|
session::Session
|
||||||
@ -73,12 +76,14 @@ function Flux.train!(m::Model, train, test=[]; epoch = 1, η = 0.1,
|
|||||||
Y = placeholder(Float64)
|
Y = placeholder(Float64)
|
||||||
Loss = loss(m.graph, Y)
|
Loss = loss(m.graph, Y)
|
||||||
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
minimize_op = TensorFlow.train.minimize(opt, Loss)
|
||||||
run(m.session, initialize_all_variables())
|
|
||||||
for e in 1:epoch
|
for e in 1:epoch
|
||||||
info("Epoch $e\n")
|
info("Epoch $e\n")
|
||||||
@progress for (x, y) in train
|
@progress for (x, y) in train
|
||||||
y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op), Dict(m.inputs[1]=>x', Y=>y'))
|
y, cur_loss, _ = run(m.session, vcat(m.graph, Loss, minimize_op), Dict(m.inputs[1]=>x', Y=>y'))
|
||||||
i % 1000 == 0 && @show accuracy(m, test)
|
if i % 5000 == 0
|
||||||
|
@show y
|
||||||
|
@show accuracy(m, test)
|
||||||
|
end
|
||||||
i += 1
|
i += 1
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user