use fancy new callback api

This commit is contained in:
Mike J Innes 2017-05-01 14:23:48 +01:00
parent b19e31714d
commit edd0a1f699
2 changed files with 9 additions and 4 deletions

View File

@ -1,4 +1,5 @@
using Flux, MNIST
using Flux: accuracy
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
train = data[1:50_000]
@ -14,9 +15,10 @@ m = @Chain(
model = mxnet(m)
# An example prediction pre-training
model(unsqueeze(data[1][1]))
model(tobatch(data[1][1]))
Flux.train!(model, train, test, η = 1e-4)
Flux.train!(model, train, η = 1e-3,
cb = [()->@show accuracy(m, test)])
# An example prediction post-training
model(unsqueeze(data[1][1]))
model(tobatch(data[1][1]))

View File

@ -14,6 +14,7 @@ alphabet = unique(input)
N = length(alphabet)
train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
eval = tobatch.(first(drop(train, 5)))
model = Chain(
Input(N),
@ -24,7 +25,9 @@ model = Chain(
m = mxnet(unroll(model, nunroll))
@time Flux.train!(m, train, η = 0.1, loss = logloss)
evalcb = () -> @show logloss(m(eval[1]), eval[2])
@time Flux.train!(m, train, η = 0.1, loss = logloss, cb = [evalcb])
function sample(model, n, temp = 1)
s = [rand(alphabet)]