diff --git a/examples/char-rnn.jl b/examples/char-rnn.jl
index 2802e232..e8ed72ab 100644
--- a/examples/char-rnn.jl
+++ b/examples/char-rnn.jl
@@ -1,23 +1,22 @@
 using Flux
+using Flux: onehot, logloss, unsqueeze
+using Flux.Batches: Batch, tobatch, seqs, chunk
 import StatsBase: wsample
 
 nunroll = 50
 nbatch = 50
 
-getseqs(chars, alphabet) =
-  sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
-getbatches(chars, alphabet) =
-  batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
+encode(input) = seqs((onehot(ch, alphabet) for ch in input), nunroll)
 
-input = readstring("$(homedir())/Downloads/shakespeare_input.txt");
+cd(@__DIR__)
+input = readstring("shakespeare_input.txt");
 alphabet = unique(input)
 N = length(alphabet)
 
-train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
-eval = tobatch.(first(drop(train, 5)))
+Xs = (Batch(ss) for ss in zip(encode.(chunk(input, 50))...))
+Ys = (Batch(ss) for ss in zip(encode.(chunk(input[2:end], 50))...))
 
 model = Chain(
-  Input(N),
   LSTM(N, 256),
   LSTM(256, 256),
   Affine(256, N),
@@ -25,9 +24,10 @@ model = Chain(
 
 m = mxnet(unroll(model, nunroll))
 
+eval = tobatch.(first.(drop.((Xs, Ys), 5)))
 evalcb = () -> @show logloss(m(eval[1]), eval[2])
 
-@time Flux.train!(m, train, η = 0.1, loss = logloss, cb = [evalcb])
+# @time Flux.train!(m, zip(Xs, Ys), η = 0.001, loss = logloss, cb = [evalcb], epoch = 10)
 
 function sample(model, n, temp = 1)
   s = [rand(alphabet)]
@@ -38,4 +38,4 @@ function sample(model, n, temp = 1)
   return string(s...)
 end
 
-s = sample(model[1:end-1], 100)
+# s = sample(model[1:end-1], 100)
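
Note: the body of sample() is elided by the hunk context above; what the diff does show is that it seeds with rand(alphabet) and that StatsBase's wsample(items, weights) is imported for the draw. As a minimal illustrative sketch (not the file's actual code; sample_char and its arguments are hypothetical), the usual temperature trick is that dividing logits by temp before softmax is equivalent to raising the probabilities to 1/temp and renormalising, so temp < 1 sharpens the distribution and temp > 1 flattens it:

using StatsBase: wsample

# Draw one character from `probs` over `alphabet`, rescaled by temperature.
function sample_char(probs, alphabet, temp = 1)
  p = probs .^ (1 / temp)   # same as softmax(logits ./ temp), on the probability scale
  p ./= sum(p)              # renormalise to a valid distribution
  wsample(alphabet, p)
end

# e.g. sample_char([0.7, 0.2, 0.1], ['a', 'b', 'c'], 0.5) favours 'a' even more strongly.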