From edd0a1f6992173addc43a332740eca28790429d3 Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Mon, 1 May 2017 14:23:48 +0100 Subject: [PATCH] use fancy new callback api --- examples/MNIST.jl | 8 +++++--- examples/char-rnn.jl | 5 ++++- 2 files changed, 9 insertions(+), 4 deletions(-) diff --git a/examples/MNIST.jl b/examples/MNIST.jl index 7befebf8..88970f0f 100644 --- a/examples/MNIST.jl +++ b/examples/MNIST.jl @@ -1,4 +1,5 @@ using Flux, MNIST +using Flux: accuracy data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000] train = data[1:50_000] @@ -14,9 +15,10 @@ m = @Chain( model = mxnet(m) # An example prediction pre-training -model(unsqueeze(data[1][1])) +model(tobatch(data[1][1])) -Flux.train!(model, train, test, η = 1e-4) +Flux.train!(model, train, η = 1e-3, + cb = [()->@show accuracy(m, test)]) # An example prediction post-training -model(unsqueeze(data[1][1])) +model(tobatch(data[1][1])) diff --git a/examples/char-rnn.jl b/examples/char-rnn.jl index 4b99f8a9..2802e232 100644 --- a/examples/char-rnn.jl +++ b/examples/char-rnn.jl @@ -14,6 +14,7 @@ alphabet = unique(input) N = length(alphabet) train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet)) +eval = tobatch.(first(drop(train, 5))) model = Chain( Input(N), @@ -24,7 +25,9 @@ model = Chain( m = mxnet(unroll(model, nunroll)) -@time Flux.train!(m, train, η = 0.1, loss = logloss) +evalcb = () -> @show logloss(m(eval[1]), eval[2]) + +@time Flux.train!(m, train, η = 0.1, loss = logloss, cb = [evalcb]) function sample(model, n, temp = 1) s = [rand(alphabet)]