Flux.jl/examples/char-rnn.jl

21 lines
600 B
Julia
Raw Normal View History

2016-10-28 23:16:39 +00:00
using Flux
2016-10-29 22:36:39 +00:00
"""
    getseqs(chars, alphabet; seqlen = 50)

One-hot encode every character of `chars` over `alphabet` (as `Float32` vectors) and pass
the lazy stream of encodings to Flux's `sequences` with length `seqlen` — presumably
grouping them into sequences of `seqlen` steps. `seqlen` should match the number of steps
the model is unrolled for.
"""
getseqs(chars, alphabet; seqlen = 50) =
    sequences((onehot(Float32, char, alphabet) for char in chars), seqlen)

"""
    getbatches(chars, alphabet; nbatch = 50, seqlen = 50)

Split `chars` into `nbatch` parts with Flux's `chunk`, encode each part with `getseqs`
(using `seqlen`), and combine the per-part sequence streams with Flux's `batches`.
The defaults reproduce the original hard-coded behaviour (50 parts, length-50 sequences).
"""
getbatches(chars, alphabet; nbatch = 50, seqlen = 50) =
    batches((getseqs(part, alphabet; seqlen = seqlen) for part in chunk(chars, nbatch))...)
2016-10-28 23:16:39 +00:00
2016-10-29 22:36:39 +00:00
# Training corpus. The path may be supplied as the first command-line argument; otherwise
# it defaults to shakespeare_input.txt in the user's Downloads directory. `joinpath` builds
# the path with the platform's separator instead of hard-coding "/".
inputpath = isempty(ARGS) ?
    joinpath(homedir(), "Downloads", "shakespeare_input.txt") : ARGS[1]
input = readstring(inputpath)
# The character set: every distinct character that occurs in the corpus.
const alphabet = unique(input)
2016-10-28 23:16:39 +00:00
2016-10-29 22:36:39 +00:00
# Training pairs for next-character prediction: each input batch is zipped with the batch
# built from the same text shifted forward by one character. Both streams are built from
# strings of the same length (all chars but the last / all chars but the first); with
# unequal lengths, splitting each into a fixed number of parts by length can cut them at
# different offsets and misalign inputs and targets. `prevind`/`nextind` keep the slices on
# valid character boundaries even if the corpus contains multi-byte UTF-8 characters.
train = zip(getbatches(input[1:prevind(input, end)], alphabet),
            getbatches(input[nextind(input, 1):end], alphabet))
2016-10-28 23:16:39 +00:00
# Character-level network: an input layer sized to the alphabet, one recurrent layer with
# width 128 in its middle argument (presumably the hidden-state size — confirm against
# Flux's `Recurrent`), and a softmax over the alphabet-sized output.
model = let nchars = length(alphabet), nhidden = 128
    Chain(Input(nchars),
          Recurrent(nchars, nhidden, nchars),
          softmax)
end
2016-10-29 22:36:39 +00:00
# Unroll the recurrent model for 50 steps (the same length passed to `sequences` when the
# data is prepared) and hand the result to `tf` — presumably Flux's TensorFlow backend.
m = unroll(model, 50) |> tf
2016-10-28 23:16:39 +00:00
2016-10-29 23:34:44 +00:00
# Fit the compiled model on the (input, target) batch pairs with a learning rate of 0.1/50.
# NOTE(review): the division by 50 presumably compensates for the 50 unrolled steps — confirm.
Flux.train!(m, train; η = 0.1 / 50)
2016-10-28 23:16:39 +00:00
2016-10-29 22:36:39 +00:00
# Sanity check: feed the first element of the first input batch through the trained model
# and decode each output vector back to a character of `alphabet`. `train` is a `zip`
# iterator, which does not support indexing (`train[1]` throws), so the first
# (input, target) pair is taken with `first` instead.
map(c -> onecold(c, alphabet), m(first(train)[1][1]))