fix up char-rnn
This commit is contained in:
parent
f7f8124a78
commit
1edabe6052
@ -4,14 +4,16 @@ import StatsBase: wsample
|
||||
nunroll = 50
|
||||
nbatch = 50
|
||||
|
||||
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
|
||||
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
|
||||
getseqs(chars, alphabet) =
|
||||
sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
|
||||
getbatches(chars, alphabet) =
|
||||
batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
|
||||
|
||||
input = readstring("$(homedir())/Downloads/shakespeare_input.txt");
|
||||
alphabet = unique(input)
|
||||
N = length(alphabet)
|
||||
|
||||
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
||||
train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
|
||||
|
||||
model = Chain(
|
||||
Input(N),
|
||||
@ -20,17 +22,17 @@ model = Chain(
|
||||
Affine(256, N),
|
||||
softmax)
|
||||
|
||||
m = tf(unroll(model, nunroll))
|
||||
m = mxnet(unroll(model, nunroll))
|
||||
|
||||
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
||||
@time Flux.train!(m, train, η = 1, epoch = 1)
|
||||
|
||||
function sample(model, n, temp = 1)
|
||||
s = [rand(alphabet)]
|
||||
m = tf(unroll(model, 1))
|
||||
for i = 1:n
|
||||
push!(s, wsample(alphabet, softmax(m(Seq((onehot(Float32, s[end], alphabet),)))[1]./temp)))
|
||||
m = mxnet(unroll1(model))
|
||||
for i = 1:n-1
|
||||
push!(s, wsample(alphabet, softmax(m(unsqueeze(onehot(s[end], alphabet)))./temp)[1,:]))
|
||||
end
|
||||
return string(s...)
|
||||
end
|
||||
|
||||
sample(model[1:end-1], 100)
|
||||
s = sample(model[1:end-1], 100)
|
||||
|
Loading…
Reference in New Issue
Block a user