Flux.jl/examples/char-rnn.jl

42 lines
1013 B
Julia
Raw Normal View History

2016-10-28 23:16:39 +00:00
using Flux
2017-06-28 16:15:44 +00:00
using Flux: onehot, logloss, unsqueeze
using Flux.Batches: Batch, tobatch, seqs, chunk
import StatsBase: wsample
2016-10-28 23:16:39 +00:00
2016-12-15 21:07:07 +00:00
# Number of time steps in each training sequence (passed to `seqs` and `unroll`).
nunroll = 50
# Number of chunks the corpus is cut into. Zipping the per-chunk sequences puts
# one sequence from each chunk into every `Batch`, so this is presumably the
# batch size.
nbatch = 50

# One-hot encode each character of `input` against the global `alphabet` and
# group the encodings into sequences of `nunroll` steps via `Flux.Batches.seqs`.
encode(input) = seqs((onehot(ch, alphabet) for ch in input), nunroll)

# Change the process working directory to this script's folder so the corpus
# path below resolves relative to the script rather than the caller's cwd.
cd(@__DIR__)
# NOTE(review): `readstring` was removed in Julia 1.0; there use
# `read("shakespeare_input.txt", String)`.
input = readstring("shakespeare_input.txt");

alphabet = unique(input)
N = length(alphabet)

# Inputs are characters 1..end-1 and targets are characters 2..end, so every
# target is the character that follows its input. Both strings have the same
# length, so (assuming `chunk` splits purely by length) they are cut at the same
# positions and the per-chunk input/target sequences stay aligned.
# NOTE(review): `String` slicing indexes bytes; this assumes an ASCII corpus.
Xs = (Batch(ss) for ss in zip(encode.(chunk(input[1:end-1], nbatch))...))
Ys = (Batch(ss) for ss in zip(encode.(chunk(input[2:end], nbatch))...))
2016-10-28 23:16:39 +00:00
2017-02-28 16:42:48 +00:00
# Two stacked LSTM layers, an affine projection back to `N` character scores,
# and a final softmax.
model = Chain(
  LSTM(N, 256),
  LSTM(256, 256),
  Affine(256, N),
  softmax)

# Unroll the model over `nunroll` time steps and wrap it with `mxnet`
# (presumably compiling it for the MXNet backend — confirm).
m = mxnet(unroll(model, nunroll))

# A fixed evaluation batch: the sixth (input, target) pair from `Xs`/`Ys`, after
# skipping the first five. It is NOT held out — it is also in the training
# stream. Named `evalbatch` instead of `eval`, which would clash with the `eval`
# function every Julia module defines.
evalbatch = tobatch.(first.(Iterators.drop.((Xs, Ys), 5)))
evalcb = () -> @show logloss(m(evalbatch[1]), evalbatch[2])

# Uncomment to train for 10 epochs, printing the evaluation loss via `evalcb`.
# @time Flux.train!(m, zip(Xs, Ys), η = 0.001, loss = logloss, cb = [evalcb], epoch = 10)
2016-10-28 23:16:39 +00:00
2016-12-13 12:27:50 +00:00
"""
    sample(model, n, temp = 1)

Generate a string of `n` characters by running `model` one step at a time.

Starts from a character drawn uniformly from the global `alphabet`. For each
further step, it feeds the one-hot encoding of the previous character to the
single-step unrolled model (`unroll1(model)`). It then divides the output by
`temp`, applies `softmax`, and draws the next character with `wsample`, using
the first row of the result as weights.

`model` presumably should output unnormalised scores: the example call at the
end of this file passes `model[1:end-1]`, dropping the final `softmax`.
NOTE(review): confirm. Larger `temp` flattens the sampling distribution;
smaller `temp` sharpens it.

# Throws
- `ArgumentError` if `n < 1`.
"""
function sample(model, n, temp = 1)
  # Previously n <= 0 silently produced a 1-character string.
  n >= 1 || throw(ArgumentError("n must be at least 1, got $n"))
  s = [rand(alphabet)]
  m = unroll1(model)
  for i = 1:n-1
    # Temperature-scale the scores, normalise them, then draw the next
    # character in proportion to the resulting weights.
    probs = softmax(m(unsqueeze(onehot(s[end], alphabet))) ./ temp)
    push!(s, wsample(alphabet, probs[1, :]))
  end
  # `join` builds the string without splatting a length-`n` vector as varargs.
  return join(s)
end
2017-06-28 16:15:44 +00:00
# s = sample(model[1:end-1], 100)