From e35380940b6c5f134a8b9f475657d2bdf38b1d4a Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sun, 30 Oct 2016 18:32:16 +0000 Subject: [PATCH] finally producing something recognisable --- examples/char-rnn.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/examples/char-rnn.jl b/examples/char-rnn.jl index 2a77b3a1..01cd0394 100644 --- a/examples/char-rnn.jl +++ b/examples/char-rnn.jl @@ -1,4 +1,5 @@ using Flux +import StatsBase: wsample getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50) getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...) @@ -17,9 +18,7 @@ model = Chain( m = tf(unroll(model, 50)) -# Flux.train!(m, take(Xs,100), take(Ys,100), -# η = 0.1, epoch = 1) -Flux.train!(m, Xs, Ys, η = 0.2e-3, epoch = 1) +Flux.train!(m, Xs, Ys, η = 0.01, epoch = 1) string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...) @@ -27,7 +26,7 @@ function sample(model, n) s = [rand(alphabet)] m = tf(unroll(model, 1)) for i = 1:n - push!(s, onecold(m(Seq((onehot(Float32, 'b', alphabet),)))[1], alphabet)) + push!(s, wsample(alphabet, m(Seq((onehot(Float32, s[end], alphabet),)))[1])) end return string(s...) end