fix up char-rnn

This commit is contained in:
Mike J Innes 2017-04-27 17:28:32 +01:00
parent f7f8124a78
commit 1edabe6052

View File

@ -4,14 +4,16 @@ import StatsBase: wsample
nunroll = 50 nunroll = 50
nbatch = 50 nbatch = 50
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), nunroll) getseqs(chars, alphabet) =
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...) sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
getbatches(chars, alphabet) =
batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
input = readstring("$(homedir())/Downloads/shakespeare_input.txt"); input = readstring("$(homedir())/Downloads/shakespeare_input.txt");
alphabet = unique(input) alphabet = unique(input)
N = length(alphabet) N = length(alphabet)
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet) train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
model = Chain( model = Chain(
Input(N), Input(N),
@ -20,17 +22,17 @@ model = Chain(
Affine(256, N), Affine(256, N),
softmax) softmax)
m = tf(unroll(model, nunroll)) m = mxnet(unroll(model, nunroll))
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1) @time Flux.train!(m, train, η = 1, epoch = 1)
function sample(model, n, temp = 1) function sample(model, n, temp = 1)
s = [rand(alphabet)] s = [rand(alphabet)]
m = tf(unroll(model, 1)) m = mxnet(unroll1(model))
for i = 1:n for i = 1:n-1
push!(s, wsample(alphabet, softmax(m(Seq((onehot(Float32, s[end], alphabet),)))[1]./temp))) push!(s, wsample(alphabet, softmax(m(unsqueeze(onehot(s[end], alphabet)))./temp)[1,:]))
end end
return string(s...) return string(s...)
end end
sample(model[1:end-1], 100) s = sample(model[1:end-1], 100)