update examples

commit 3d4c8fa73b (parent d7ff193ad6)
@@ -13,8 +13,10 @@ Firstly, we define up front how many steps we want to unroll the RNN, and the nu
 nunroll = 50
 nbatch = 50

-getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
-getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
+getseqs(chars, alphabet) =
+  sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
+getbatches(chars, alphabet) =
+  batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
 ```

 Because we want the RNN to predict the next letter at each iteration, our target data is simply our input data offset by one. For example, if the input is "The quick brown fox", the target will be "he quick brown fox ". Each letter is one-hot encoded and sequences are batched together to create the training data.
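The paragraph at the end of this hunk describes the offset-by-one target construction. As a minimal sketch of that idea in plain Julia (independent of the `sequences`/`batches` pipeline above, and using a hypothetical `onehot_char` helper rather than Flux's `onehot`):

```julia
# Targets are the input shifted by one character: the model sees 'T' and must predict 'h', and so on.
text = "The quick brown fox"
xs = collect(text[1:end-1])   # input characters
ys = collect(text[2:end])     # target characters, offset by one

# Toy one-hot encoding over this string's alphabet (illustrative helper, not Flux's onehot):
alphabet = unique(text)
onehot_char(c) = Float32.(alphabet .== c)

X = [onehot_char(c) for c in xs]
Y = [onehot_char(c) for c in ys]
```

In the example itself the same offset is achieved by pairing `getbatches(input, alphabet)` with `getbatches(input[2:end], alphabet)`.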
@@ -24,7 +26,10 @@ input = readstring("shakespeare_input.txt");
 alphabet = unique(input)
 N = length(alphabet)

-Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
+# An iterator of (input, output) pairs
+train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
+# We will evaluate the loss on a particular batch to monitor the training.
+eval = tobatch.(first(drop(train, 5)))
 ```

 Creating the model and training it is straightforward:
@@ -39,7 +44,11 @@ model = Chain(

 m = tf(unroll(model, nunroll))

-@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
+# Call this to see how the model is doing
+evalcb = () -> @show logloss(m(eval[1]), eval[2])
+
+@time Flux.train!(m, train, η = 0.1, loss = logloss, cb = [evalcb])
+
 ```

 Finally, we can sample the model. For sampling we remove the `softmax` from the end of the chain so that we can "sharpen" the resulting probabilities.
@@ -47,9 +56,9 @@ Finally, we can sample the model. For sampling we remove the `softmax` from the
 ```julia
 function sample(model, n, temp = 1)
   s = [rand(alphabet)]
-  m = tf(unroll(model, 1))
-  for i = 1:n
-    push!(s, wsample(alphabet, softmax(m(Seq((onehot(Float32, s[end], alphabet),)))[1]./temp)))
+  m = unroll1(model)
+  for i = 1:n-1
+    push!(s, wsample(alphabet, softmax(m(unsqueeze(onehot(s[end], alphabet)))./temp)[1,:]))
   end
   return string(s...)
 end
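The "sharpening" mentioned above is why `sample` divides the raw outputs by `temp` before applying `softmax`. A small self-contained illustration of the effect, with a standalone `softmax` definition and hypothetical scores (not the Flux function or real model outputs):

```julia
# Standalone softmax, for illustration only.
softmax(xs) = exp.(xs) ./ sum(exp.(xs))

scores = [2.0, 1.0, 0.1]    # hypothetical un-normalised model outputs

softmax(scores)             # ≈ [0.659, 0.242, 0.099]  (temp = 1)
softmax(scores ./ 0.5)      # ≈ [0.864, 0.117, 0.019]  temp < 1 sharpens the distribution
softmax(scores ./ 2.0)      # ≈ [0.502, 0.304, 0.194]  temp > 1 flattens it
```

Lower temperatures make the sampled text more conservative; higher temperatures make it more varied but noisier.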
@@ -6,6 +6,7 @@ First, we load the data using the MNIST package:

 ```julia
 using Flux, MNIST
+using Flux: accuracy

 data = [(trainfeatures(i), onehot(trainlabel(i), 0:9)) for i = 1:60_000]
 train = data[1:50_000]
@@ -34,7 +35,7 @@ julia> data[1]
 Now we define our model, which will simply be a function from one to the other.

 ```julia
-m = Chain(
+m = @Chain(
   Input(784),
   Affine(128), relu,
   Affine( 64), relu,
@@ -46,7 +47,7 @@ model = mxnet(m) # Convert to MXNet
 We can try this out on our data already:

 ```julia
-julia> model(data[1][1])
+julia> model(tobatch(data[1][1]))
 10-element Array{Float64,1}:
  0.10614
  0.0850447
@@ -57,7 +58,8 @@ julia> model(data[1][1])
 The model gives a probability of about 0.1 to each class – which is a way of saying, "I have no idea". This isn't too surprising as we haven't shown it any data yet. This is easy to fix:

 ```julia
-Flux.train!(model, train, test, η = 1e-4)
+Flux.train!(model, train, η = 1e-3,
+            cb = [()->@show accuracy(m, test)])
 ```

 The training step takes about 5 minutes (to make it faster we can do smarter things like batching). If you run this code in Juno, you'll see a progress meter, which you can hover over to see the remaining computation time.
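The note above points to batching as the main way to speed this up. A hedged sketch of what minibatching the `data` pairs defined earlier could look like in plain Julia; the `make_batches` helper and the batch size of 100 are illustrative assumptions, not part of the example:

```julia
# Stack groups of examples into matrices so each training step processes a whole minibatch.
function make_batches(data, batchsize = 100)
    batches = []
    for i in 1:batchsize:(length(data) - batchsize + 1)
        group = data[i:i+batchsize-1]
        X = hcat(first.(group)...)   # 784 × batchsize matrix of input features
        Y = hcat(last.(group)...)    # 10 × batchsize matrix of one-hot labels
        push!(batches, (X, Y))
    end
    return batches
end
```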
@@ -80,7 +82,7 @@ Notice the class at 93%, suggesting our model is very confident about this image
 julia> onecold(data[1][2], 0:9)
 5

-julia> onecold(model(data[1][1]), 0:9)
+julia> onecold(model(tobatch(data[1][1])), 0:9)
 5
 ```
