char-rnn temperature
This commit is contained in:
parent
2aa8dfc208
commit
a63cd826c2
@ -23,11 +23,11 @@ m = tf(unroll(model, 50));
|
|||||||
|
|
||||||
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
|
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
|
||||||
|
|
||||||
function sample(model, n)
|
function sample(model, n, temp = 1)
|
||||||
s = [rand(alphabet)]
|
s = [rand(alphabet)]
|
||||||
m = tf(unroll(model, 1))
|
m = tf(unroll(model, 1))
|
||||||
for i = 1:n
|
for i = 1:n
|
||||||
push!(s, wsample(alphabet, m(Seq((onehot(Float32, s[end], alphabet),)))[1]))
|
push!(s, wsample(alphabet, softmax(m(Seq((onehot(Float32, s[end], alphabet),)))[1]./temp)))
|
||||||
end
|
end
|
||||||
return string(s...)
|
return string(s...)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user