char-rnn updates
This commit is contained in:
parent
ad6e6b4116
commit
5b50d58381
|
@ -12,14 +12,14 @@ Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
|||
|
||||
model = Chain(
|
||||
Input(N),
|
||||
LSTM(N, 128),
|
||||
LSTM(128, 128),
|
||||
Dense(128, N),
|
||||
LSTM(N, 256),
|
||||
LSTM(256, 256),
|
||||
Dense(256, N),
|
||||
softmax)
|
||||
|
||||
m = tf(unroll(model, 50))
|
||||
m = tf(unroll(model, 50));
|
||||
|
||||
Flux.train!(m, Xs, Ys, η = 0.01, epoch = 1)
|
||||
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
||||
|
||||
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
|
||||
|
||||
|
@ -32,4 +32,4 @@ function sample(model, n)
|
|||
return string(s...)
|
||||
end
|
||||
|
||||
sample(model, 100) |> println
|
||||
sample(model, 100)
|
||||
|
|
Loading…
Reference in New Issue