2016-10-28 23:16:39 +00:00
|
|
|
using Flux
|
2016-10-30 18:32:16 +00:00
|
|
|
import StatsBase: wsample
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2016-12-15 21:07:07 +00:00
|
|
|
# Number of time steps per training sequence. Passed to `sequences` in `getseqs`
# and to `unroll` when the model is prepared for training.
nunroll = 50
|
|
|
|
# Passed to `chunk` in `getbatches`.
# NOTE(review): presumably the number of parallel chunks, i.e. the batch size — confirm.
nbatch = 50
|
|
|
|
|
2017-04-27 16:28:32 +00:00
|
|
|
# Turn `chars` into model-ready sequences.
#
# Each character is one-hot encoded as `Float32` against `alphabet`; the lazy stream
# of encodings is handed to `sequences` together with the global `nunroll`.
# NOTE(review): `sequences` presumably groups the stream into runs of `nunroll`
# steps — confirm against the Flux version in use.
function getseqs(chars, alphabet)
    encoded = (onehot(Float32, char, alphabet) for char in chars)
    return sequences(encoded, nunroll)
end
|
|
|
|
# Build batched sequences from `chars`.
#
# `chars` is split with `chunk(chars, nbatch)`, every part is converted with `getseqs`,
# and the resulting per-part sequences are passed to `batches` as separate arguments.
# NOTE(review): `batches` presumably combines the parts step by step — confirm.
function getbatches(chars, alphabet)
    perpart = (getseqs(part, alphabet) for part in chunk(chars, nbatch))
    return batches(perpart...)
end
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2017-03-15 13:56:00 +00:00
|
|
|
# Read the whole training corpus into a single string. The path is hard-coded to
# `Downloads/shakespeare_input.txt` under the current user's home directory.
# The trailing `;` suppresses echoing the text at the REPL.
input = readstring("$(homedir())/Downloads/shakespeare_input.txt");
|
2016-10-30 14:12:32 +00:00
|
|
|
# Distinct characters of the corpus, in order of first appearance. Used as the
# vocabulary for one-hot encoding and for sampling.
alphabet = unique(input)
|
|
|
|
# Vocabulary size; used as the width of the model's input and output layers.
N = length(alphabet)
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2017-04-27 16:28:32 +00:00
|
|
|
# Training pairs: batches built from the text, zipped with batches built from the
# text shifted by one position, so the second element of each pair is the
# next-character target for the first.
# Note: `input[2:end]` slices by byte index, so this relies on the first character of
# the corpus being a single byte (true for ASCII text).
train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2017-02-28 16:42:48 +00:00
|
|
|
# Character-level model: width-N input, two stacked LSTM layers with 256 units each,
# an affine layer back to width N, and a final softmax.
# NOTE(review): `Input(N)` presumably just declares the input width — confirm.
model = Chain(
    Input(N),
    LSTM(N, 256), LSTM(256, 256),
    Affine(256, N),
    softmax)
|
|
|
|
|
2017-04-27 16:28:32 +00:00
|
|
|
# Unroll the recurrent model over `nunroll` steps and wrap the result with `mxnet`.
# NOTE(review): `mxnet` presumably converts the model to run on the MXNet backend — confirm.
m = mxnet(unroll(model, nunroll))
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2017-04-28 16:14:35 +00:00
|
|
|
# Train `m` on the (input, target) pairs using `Flux.logloss`; `@time` reports the
# elapsed time and allocations of the call.
# NOTE(review): `η` is presumably the learning rate — confirm.
@time Flux.train!(m, train, η = 0.1, loss = Flux.logloss)
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2016-12-13 12:27:50 +00:00
|
|
|
"""
    sample(model, n, temp = 1)

Generate a string of `n` characters from `model`, one character at a time.

The first character is drawn uniformly from the global `alphabet`. Each further
character is drawn with `wsample` over `alphabet`, weighted by the softmax of the
model output for the previous character divided by `temp`. `model` is wrapped with
`unroll1` before it is called. Because a softmax is applied here, `model` should not
end in one of its own.

Returns a `String`; the result is empty when `n < 1`.
"""
function sample(model, n, temp = 1)
    # Nothing to generate. Without this guard the seed character was still returned,
    # giving a one-character result for n <= 0.
    n < 1 && return ""
    s = [rand(alphabet)]
    m = unroll1(model)
    for i = 1:n-1
        # Encode the latest character, run one model step, rescale by the temperature.
        scores = m(unsqueeze(onehot(s[end], alphabet))) ./ temp
        # Row 1 of the softmax output supplies the sampling weights.
        # NOTE(review): presumably the single batch entry added by `unsqueeze` — confirm.
        push!(s, wsample(alphabet, softmax(scores)[1, :]))
    end
    # `join` concatenates the characters without splatting an n-element vector
    # into varargs, which `string(s...)` did.
    return join(s)
end
|
|
|
|
|
2017-04-27 16:28:32 +00:00
|
|
|
# Generate 100 characters. The final `softmax` layer is dropped (`model[1:end-1]`)
# because `sample` applies its own softmax after dividing by the temperature.
s = sample(model[1:end-1], 100)
|