char-rnn temperature

This commit is contained in:
Mike J Innes 2016-12-13 12:27:50 +00:00
parent 2aa8dfc208
commit a63cd826c2
1 changed files with 2 additions and 2 deletions

View File

@ -23,11 +23,11 @@ m = tf(unroll(model, 50));
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
function sample(model, n)
function sample(model, n, temp = 1)
s = [rand(alphabet)]
m = tf(unroll(model, 1))
for i = 1:n
push!(s, wsample(alphabet, m(Seq((onehot(Float32, s[end], alphabet),)))[1]))
push!(s, wsample(alphabet, softmax(m(Seq((onehot(Float32, s[end], alphabet),)))[1]./temp)))
end
return string(s...)
end