translation model updates
This commit is contained in:
parent
ea5d43ed77
commit
2be51b8c84
@ -1,23 +1,22 @@
|
||||
using Flux
|
||||
using Flux: onehot, logloss, unsqueeze
|
||||
using Flux.Batches: Batch, tobatch, seqs, chunk
|
||||
import StatsBase: wsample
|
||||
|
||||
nunroll = 50
|
||||
nbatch = 50
|
||||
|
||||
getseqs(chars, alphabet) =
|
||||
sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
|
||||
getbatches(chars, alphabet) =
|
||||
batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
|
||||
encode(input) = seqs((onehot(ch, alphabet) for ch in input), nunroll)
|
||||
|
||||
input = readstring("$(homedir())/Downloads/shakespeare_input.txt");
|
||||
cd(@__DIR__)
|
||||
input = readstring("shakespeare_input.txt");
|
||||
alphabet = unique(input)
|
||||
N = length(alphabet)
|
||||
|
||||
train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
|
||||
eval = tobatch.(first(drop(train, 5)))
|
||||
Xs = (Batch(ss) for ss in zip(encode.(chunk(input, 50))...))
|
||||
Ys = (Batch(ss) for ss in zip(encode.(chunk(input[2:end], 50))...))
|
||||
|
||||
model = Chain(
|
||||
Input(N),
|
||||
LSTM(N, 256),
|
||||
LSTM(256, 256),
|
||||
Affine(256, N),
|
||||
@ -25,9 +24,10 @@ model = Chain(
|
||||
|
||||
m = mxnet(unroll(model, nunroll))
|
||||
|
||||
eval = tobatch.(first.(drop.((Xs, Ys), 5)))
|
||||
evalcb = () -> @show logloss(m(eval[1]), eval[2])
|
||||
|
||||
@time Flux.train!(m, train, η = 0.1, loss = logloss, cb = [evalcb])
|
||||
# @time Flux.train!(m, zip(Xs, Ys), η = 0.001, loss = logloss, cb = [evalcb], epoch = 10)
|
||||
|
||||
function sample(model, n, temp = 1)
|
||||
s = [rand(alphabet)]
|
||||
@ -38,4 +38,4 @@ function sample(model, n, temp = 1)
|
||||
return string(s...)
|
||||
end
|
||||
|
||||
s = sample(model[1:end-1], 100)
|
||||
# s = sample(model[1:end-1], 100)
|
||||
|
Loading…
Reference in New Issue
Block a user