diff --git a/examples/char-rnn.jl b/examples/char-rnn.jl index 1da9e512..7933c6eb 100644 --- a/examples/char-rnn.jl +++ b/examples/char-rnn.jl @@ -1,7 +1,5 @@ using Flux -using Juno - getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50) getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...) @@ -17,6 +15,6 @@ model = Chain( m = tf(unroll(model, 50)) -Flux.train!(m, train, η = 0.1/50, epoch = 5) +Flux.train!(m, train, η = 0.1/50) map(c->onecold(c, alphabet), m(train[1][1][1]))