From 8db7df3f51067aebfb297f13a8c4f54dd10791ed Mon Sep 17 00:00:00 2001 From: Mike J Innes Date: Sun, 30 Oct 2016 14:12:32 +0000 Subject: [PATCH] api tweaks --- examples/char-rnn.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/examples/char-rnn.jl b/examples/char-rnn.jl index bef8fa11..218edfc7 100644 --- a/examples/char-rnn.jl +++ b/examples/char-rnn.jl @@ -4,17 +4,18 @@ getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char i getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...) input = readstring("$(homedir())/Downloads/shakespeare_input.txt") -const alphabet = unique(input) +alphabet = unique(input) +N = length(alphabet) -train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet)) +Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet) model = Chain( - Input(length(alphabet)), - Recurrent(length(alphabet), 128, length(alphabet)), + Input(N), + Recurrent(N, 128, N), softmax) m = tf(unroll(model, 50)) -Flux.train!(m, train, η = 0.1/50) +Flux.train!(m, Xs, Ys, η = 0.2e-3, epoch = 1) -map(c->onecold(c, alphabet), m(train[1][1][1])) +map(c->onecold(c, alphabet), m(first(first(first(train)))))