char-rnn updates

This commit is contained in:
Mike J Innes 2016-11-08 19:22:10 +00:00
parent ad6e6b4116
commit 5b50d58381

View File

@ -12,14 +12,14 @@ Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
model = Chain( model = Chain(
Input(N), Input(N),
LSTM(N, 128), LSTM(N, 256),
LSTM(128, 128), LSTM(256, 256),
Dense(128, N), Dense(256, N),
softmax) softmax)
m = tf(unroll(model, 50)) m = tf(unroll(model, 50));
Flux.train!(m, Xs, Ys, η = 0.01, epoch = 1) @time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...) string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
@ -32,4 +32,4 @@ function sample(model, n)
return string(s...) return string(s...)
end end
sample(model, 100) |> println sample(model, 100)