use fancy new callback api
This commit is contained in:
parent
b19e31714d
commit
edd0a1f699
@ -1,4 +1,5 @@
|
|||||||
using Flux, MNIST
|
using Flux, MNIST
|
||||||
|
using Flux: accuracy
|
||||||
|
|
||||||
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
|
||||||
train = data[1:50_000]
|
train = data[1:50_000]
|
||||||
@ -14,9 +15,10 @@ m = @Chain(
|
|||||||
model = mxnet(m)
|
model = mxnet(m)
|
||||||
|
|
||||||
# An example prediction pre-training
|
# An example prediction pre-training
|
||||||
model(unsqueeze(data[1][1]))
|
model(tobatch(data[1][1]))
|
||||||
|
|
||||||
Flux.train!(model, train, test, η = 1e-4)
|
Flux.train!(model, train, η = 1e-3,
|
||||||
|
cb = [()->@show accuracy(m, test)])
|
||||||
|
|
||||||
# An example prediction post-training
|
# An example prediction post-training
|
||||||
model(unsqueeze(data[1][1]))
|
model(tobatch(data[1][1]))
|
||||||
|
@ -14,6 +14,7 @@ alphabet = unique(input)
|
|||||||
N = length(alphabet)
|
N = length(alphabet)
|
||||||
|
|
||||||
train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
|
train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
|
||||||
|
eval = tobatch.(first(drop(train, 5)))
|
||||||
|
|
||||||
model = Chain(
|
model = Chain(
|
||||||
Input(N),
|
Input(N),
|
||||||
@ -24,7 +25,9 @@ model = Chain(
|
|||||||
|
|
||||||
m = mxnet(unroll(model, nunroll))
|
m = mxnet(unroll(model, nunroll))
|
||||||
|
|
||||||
@time Flux.train!(m, train, η = 0.1, loss = logloss)
|
evalcb = () -> @show logloss(m(eval[1]), eval[2])
|
||||||
|
|
||||||
|
@time Flux.train!(m, train, η = 0.1, loss = logloss, cb = [evalcb])
|
||||||
|
|
||||||
function sample(model, n, temp = 1)
|
function sample(model, n, temp = 1)
|
||||||
s = [rand(alphabet)]
|
s = [rand(alphabet)]
|
||||||
|
Loading…
Reference in New Issue
Block a user