use fancy new callback api
This commit is contained in:
parent
b19e31714d
commit
edd0a1f699
@ -1,4 +1,5 @@
|
||||
using Flux, MNIST
|
||||
using Flux: accuracy
|
||||
|
||||
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
||||
train = data[1:50_000]
|
||||
@ -14,9 +15,10 @@ m = @Chain(
|
||||
model = mxnet(m)
|
||||
|
||||
# An example prediction pre-training
|
||||
model(unsqueeze(data[1][1]))
|
||||
model(tobatch(data[1][1]))
|
||||
|
||||
Flux.train!(model, train, test, η = 1e-4)
|
||||
Flux.train!(model, train, η = 1e-3,
|
||||
cb = [()->@show accuracy(m, test)])
|
||||
|
||||
# An example prediction post-training
|
||||
model(unsqueeze(data[1][1]))
|
||||
model(tobatch(data[1][1]))
|
||||
|
@ -14,6 +14,7 @@ alphabet = unique(input)
|
||||
N = length(alphabet)
|
||||
|
||||
train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
|
||||
eval = tobatch.(first(drop(train, 5)))
|
||||
|
||||
model = Chain(
|
||||
Input(N),
|
||||
@ -24,7 +25,9 @@ model = Chain(
|
||||
|
||||
m = mxnet(unroll(model, nunroll))
|
||||
|
||||
@time Flux.train!(m, train, η = 0.1, loss = logloss)
|
||||
evalcb = () -> @show logloss(m(eval[1]), eval[2])
|
||||
|
||||
@time Flux.train!(m, train, η = 0.1, loss = logloss, cb = [evalcb])
|
||||
|
||||
function sample(model, n, temp = 1)
|
||||
s = [rand(alphabet)]
|
||||
|
Loading…
Reference in New Issue
Block a user