# Flux.jl/examples/char-rnn.jl (39 lines, 939 B, Julia)
# Character-level LSTM language model trained on a Shakespeare text corpus.
using Flux
import StatsBase: wsample
# Sequence length: passed to `sequences` below and later to `unroll`.
nunroll = 50
# Number of pieces the text is split into via `chunk` before batching.
nbatch = 50

# One-hot encode each character of `chars` over `alphabet` (as Float32) and
# group the encoded stream into sequences of length `nunroll`.
function getseqs(chars, alphabet)
    encoded = (onehot(Float32, c, alphabet) for c in chars)
    return sequences(encoded, nunroll)
end

# Split `chars` into `nbatch` pieces with `chunk`, build the sequences of each
# piece, and combine them with `batches`.
# NOTE(review): presumably each piece becomes one lane of the batch — confirm
# against Flux's `batches`.
function getbatches(chars, alphabet)
    streams = (getseqs(piece, alphabet) for piece in chunk(chars, nbatch))
    return batches(streams...)
end
# Training corpus, read from ~/Downloads/shakespeare_input.txt.
input = readstring("$(homedir())/Downloads/shakespeare_input.txt")

# Character vocabulary (in order of first appearance) and its size, which is
# the width of the one-hot encoding and of the model's input/output.
alphabet = unique(input)
N = length(alphabet)

# Inputs are the text without its last character; targets are the text shifted
# by one, so input character t is paired with character t + 1.
# Both slices have the same length. `getbatches` therefore cuts them into
# chunks and sequences at the same offsets, and X/Y stay aligned. With unequal
# lengths the chunk sizes can differ, which misaligns the pairs.
# NOTE(review): byte-index slicing assumes the first and last characters are
# single-byte (true for ASCII text) — confirm for other corpora.
Xs, Ys = getbatches(input[1:end-1], alphabet), getbatches(input[2:end], alphabet)
# Input of width `N`, two stacked LSTM layers of width 256, an affine readout
# back to `N` scores, and a final `softmax`.
model = Chain(
    Input(N),
    LSTM(N, 256),
    LSTM(256, 256),
    Affine(256, N),
    softmax)

# Unroll the recurrent model over `nunroll` steps and wrap it with `tf`
# (presumably Flux's TensorFlow backend — confirm). Then train it on the
# batches with η = 0.1 and epoch = 1, timing the run.
m = unroll(model, nunroll) |> tf
@time Flux.train!(m, Xs, Ys; η = 0.1, epoch = 1)
# Decode the trained model's outputs for `first(first(Xs))` back into text via
# `onecold` (presumably the first sequence of the first batch — confirm).
let predictions = m(first(first(Xs)))
    join(map(ŷ -> onecold(ŷ, alphabet), predictions))
end
"""
    sample(model, n, temp = 1)

Generate `n` characters from `model` one step at a time and return them, along
with a random seed character drawn from the global `alphabet`, as a `String`
of `n + 1` characters.

Each step feeds the last generated character (one-hot encoded) to a 1-step
unrolled copy of `model`. It then draws the next character with weights
proportional to `p .^ (1 / temp)`, where `p` is the model's output.

`temp` must be positive:
- `temp = 1` samples from `p` directly.
- Smaller values sharpen the distribution.
- Larger values flatten it.

Assumes `model` ends in `softmax` (as the model built in this script does), so
that `p` is a probability vector.
"""
function sample(model, n, temp = 1)
    temp > 0 || throw(ArgumentError("temp must be positive, got $temp"))
    s = [rand(alphabet)]
    # Unroll a single step so the model can be driven one character at a time.
    m = tf(unroll(model, 1))
    for _ in 1:n
        x = Seq((onehot(Float32, s[end], alphabet),))
        p = m(x)[1]
        # `p` is already a softmax output. Applying softmax again (the
        # previous code) pushed it toward uniform. p^(1/temp) equals
        # softmax(log(p) / temp) up to a constant factor, and `wsample`
        # accepts unnormalised weights.
        push!(s, wsample(alphabet, p .^ (1 / temp)))
    end
    return join(s)
end
# Generate 100 characters from the trained model at temperature 1.
sample(model, 100, 1)