update examples

Mike J Innes 2017-05-02 13:35:53 +01:00
parent d7ff193ad6
commit 3d4c8fa73b
2 changed files with 22 additions and 11 deletions


@@ -13,8 +13,10 @@ Firstly, we define up front how many steps we want to unroll the RNN, and the nu
nunroll = 50
nbatch = 50
-getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
-getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
+getseqs(chars, alphabet) =
+  sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
+getbatches(chars, alphabet) =
+  batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
```
Because we want the RNN to predict the next letter at each iteration, our target data is simply our input data offset by one. For example, if the input is "The quick brown fox", the target will be "he quick brown fox ". Each letter is one-hot encoded and sequences are batched together to create the training data.
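As a rough stand-alone illustration of that offset (my own sketch, not part of this commit), the pairing of inputs and targets looks like:
```julia
# Sketch only: pair each character with the character that follows it.
text    = "The quick brown fox"
inputs  = collect(text)[1:end-1]   # 'T', 'h', 'e', ' ', 'q', …
targets = collect(text)[2:end]     # 'h', 'e', ' ', 'q', 'u', …
# Each character is then one-hot encoded against the alphabet before batching.
```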
@@ -24,7 +26,10 @@ input = readstring("shakespeare_input.txt");
alphabet = unique(input)
N = length(alphabet)
-Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
+# An iterator of (input, output) pairs
+train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
+# We will evaluate the loss on a particular batch to monitor the training.
+eval = tobatch.(first(drop(train, 5)))
```
Creating the model and training it is straightforward:
@@ -39,7 +44,11 @@ model = Chain(
m = tf(unroll(model, nunroll))
-@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
+# Call this to see how the model is doing
+evalcb = () -> @show logloss(m(eval[1]), eval[2])
+@time Flux.train!(m, train, η = 0.1, loss = logloss, cb = [evalcb])
```
Finally, we can sample the model. For sampling we remove the `softmax` from the end of the chain so that we can "sharpen" the resulting probabilities.
@@ -47,9 +56,9 @@ Finally, we can sample the model. For sampling we remove the `softmax` from the
```julia
function sample(model, n, temp = 1)
  s = [rand(alphabet)]
-  m = tf(unroll(model, 1))
-  for i = 1:n
-    push!(s, wsample(alphabet, softmax(m(Seq((onehot(Float32, s[end], alphabet),)))[1]./temp)))
+  m = unroll1(model)
+  for i = 1:n-1
+    push!(s, wsample(alphabet, softmax(m(unsqueeze(onehot(s[end], alphabet)))./temp)[1,:]))
  end
  return string(s...)
end
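The `./temp` above is where the "sharpening" happens. A small stand-alone sketch of the effect (mine, not part of this commit), with `softmax` written out by hand so nothing from Flux is assumed:
```julia
# Sketch only: dividing scores by a temperature below 1 sharpens the softmax
# distribution; a temperature above 1 flattens it.
softmax(x) = exp.(x) ./ sum(exp.(x))

scores = [1.0, 2.0, 4.0]
softmax(scores)         # ≈ [0.04, 0.11, 0.84]   temp = 1
softmax(scores ./ 0.5)  # ≈ [0.002, 0.018, 0.98] temp = 0.5, sharper
softmax(scores ./ 2.0)  # ≈ [0.14, 0.23, 0.63]   temp = 2, flatter
```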


@@ -6,6 +6,7 @@ First, we load the data using the MNIST package:
```julia
using Flux, MNIST
+using Flux: accuracy
data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
train = data[1:50_000]
@@ -34,7 +35,7 @@ julia> data[1]
Now we define our model, which will simply be a function from one to the other.
```julia
-m = Chain(
+m = @Chain(
  Input(784),
  Affine(128), relu,
  Affine( 64), relu,
@@ -46,7 +47,7 @@ model = mxnet(m) # Convert to MXNet
We can try this out on our data already:
```julia
-julia> model(data[1][1])
+julia> model(tobatch(data[1][1]))
10-element Array{Float64,1}:
0.10614
0.0850447
@@ -57,7 +58,8 @@ julia> model(data[1][1])
The model gives a probability of about 0.1 to each class, which is a way of saying "I have no idea". This isn't too surprising as we haven't shown it any data yet. This is easy to fix:
```julia
-Flux.train!(model, train, test, η = 1e-4)
+Flux.train!(model, train, η = 1e-3,
+  cb = [()->@show accuracy(m, test)])
```
The training step takes about 5 minutes (to make it faster we can do smarter things like batching). If you run this code in Juno, you'll see a progress meter, which you can hover over to see the remaining computation time.
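For context on the batching remark: batching just means stacking several examples into the columns of one matrix so the backend can process them together. A hand-rolled sketch of the idea (mine, not the API used in this commit):
```julia
# Sketch only: group (features, label) pairs and stack them column-wise.
for batch in Iterators.partition(train, 100)
  xs = reduce(hcat, first.(batch))  # 784×100 matrix of pixel features
  ys = reduce(hcat, last.(batch))   # 10×100 matrix of one-hot labels
  # a batched training step would consume (xs, ys) here
end
```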
@@ -80,7 +82,7 @@ Notice the class at 93%, suggesting our model is very confident about this image
julia> onecold(data[1][2], 0:9)
5
-julia> onecold(model(data[1][1]), 0:9)
+julia> onecold(model(tobatch(data[1][1])), 0:9)
5
```
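As a reading aid for the snippet above: `onecold` acts as the inverse of `onehot`, mapping the position of the largest entry back into the label range. A minimal stand-in (my sketch, not Flux's implementation):
```julia
# Sketch only: pick the label whose entry is largest.
onecold_sketch(v, labels) = labels[findmax(v)[2]]

onecold_sketch([0.01, 0.02, 0.93, 0.04], 0:3)  # → 2
```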