2016-10-28 23:16:39 +00:00
|
|
|
using Flux
|
|
|
|
|
2016-10-29 22:36:39 +00:00
|
|
|
"""
    getseqs(chars, alphabet)

Lazily one-hot encode every character of `chars` against `alphabet` (element type
`Float32`) and hand the encoded stream to Flux's `sequences` with length 50.
The 50 presumably is the sequence length; it matches the `unroll(model, 50)` below.
"""
function getseqs(chars, alphabet)
    encoded = (onehot(Float32, ch, alphabet) for ch in chars)
    return sequences(encoded, 50)
end
|
|
|
|
"""
    getbatches(chars, alphabet)

Split `chars` with Flux's `chunk(chars, 50)` and encode each part with `getseqs`.
Then combine the per-part sequences with `batches`, passing them as separate
positional arguments. Presumably each part becomes one lane of the batch; confirm
against Flux's `chunk`/`batches`.
"""
function getbatches(chars, alphabet)
    parts = chunk(chars, 50)
    perpart = (getseqs(p, alphabet) for p in parts)
    return batches(perpart...)
end
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2016-10-29 22:36:39 +00:00
|
|
|
# Raw training corpus, read into a single string.
# The path defaults to the Shakespeare text in ~/Downloads. Set the
# SHAKESPEARE_PATH environment variable to use a different file without
# editing the script.
input = readstring(get(ENV, "SHAKESPEARE_PATH",
                       "$(homedir())/Downloads/shakespeare_input.txt"))
|
2016-10-30 14:12:32 +00:00
|
|
|
# Vocabulary: the distinct characters of the corpus, in order of first
# occurrence. N is its size, used below as the model's input/output width.
alphabet = unique(collect(input))
N = length(alphabet)
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2016-10-30 14:12:32 +00:00
|
|
|
# Next-character prediction: each target is the character that follows the
# corresponding input character.
# The final character has no successor, so it is dropped from the inputs.
# This keeps both streams the same length and chunked identically.
# `prevind`/`nextind` keep the slices on character boundaries, so a multibyte
# first or last character cannot cause a string index error.
Xs, Ys = (getbatches(input[1:prevind(input, end)], alphabet),
          getbatches(input[nextind(input, 1):end], alphabet))
|
2016-10-28 23:16:39 +00:00
|
|
|
|
|
|
|
# Character-level model: an input layer of width N, one recurrent layer, then a
# softmax over the outputs.
# NOTE(review): Recurrent(N, 128, N) presumably means input size N, hidden size
# 128, output size N — confirm against the Flux version in use.
model = Chain(Input(N), Recurrent(N, 128, N), softmax)
|
|
|
|
|
2016-10-29 22:36:39 +00:00
|
|
|
# Unroll the recurrent model over 50 steps (the same 50 passed to `sequences`
# in getseqs), then hand it to `tf`, presumably the TensorFlow backend wrapper.
m = unroll(model, 50) |> tf
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2016-10-30 14:12:32 +00:00
|
|
|
# Fit the unrolled model on the input/target batches: learning rate 2e-4, one epoch.
Flux.train!(m, Xs, Ys; η = 2e-4, epoch = 1)
|
2016-10-28 23:16:39 +00:00
|
|
|
|
2016-10-30 14:12:32 +00:00
|
|
|
# Sanity check: run the trained model on an element of the training inputs and
# decode each output vector back to a character of `alphabet`.
# The training inputs live in `Xs`; there is no `train` variable in this script.
# NOTE(review): first(first(first(Xs))) presumably selects one sample sequence
# from the first batch — confirm against Flux's `batches` layout.
map(c -> onecold(c, alphabet), m(first(first(first(Xs)))))
|