api tweaks
This commit is contained in:
parent
b443425c6d
commit
8db7df3f51
@ -4,17 +4,18 @@ getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char i
|
||||
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...)
|
||||
|
||||
input = readstring("$(homedir())/Downloads/shakespeare_input.txt")
|
||||
const alphabet = unique(input)
|
||||
alphabet = unique(input)
|
||||
N = length(alphabet)
|
||||
|
||||
train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
|
||||
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
||||
|
||||
model = Chain(
|
||||
Input(length(alphabet)),
|
||||
Recurrent(length(alphabet), 128, length(alphabet)),
|
||||
Input(N),
|
||||
Recurrent(N, 128, N),
|
||||
softmax)
|
||||
|
||||
m = tf(unroll(model, 50))
|
||||
|
||||
Flux.train!(m, train, η = 0.1/50)
|
||||
Flux.train!(m, Xs, Ys, η = 0.2e-3, epoch = 1)
|
||||
|
||||
map(c->onecold(c, alphabet), m(train[1][1][1]))
|
||||
map(c->onecold(c, alphabet), m(first(first(first(train)))))
|
||||
|
Loading…
Reference in New Issue
Block a user