finally producing something recognisable
This commit is contained in:
parent
4517e41226
commit
e35380940b
@ -1,4 +1,5 @@
|
|||||||
using Flux
|
using Flux
|
||||||
|
import StatsBase: wsample
|
||||||
|
|
||||||
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50)
|
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50)
|
||||||
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...)
|
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...)
|
||||||
@ -17,9 +18,7 @@ model = Chain(
|
|||||||
|
|
||||||
m = tf(unroll(model, 50))
|
m = tf(unroll(model, 50))
|
||||||
|
|
||||||
# Flux.train!(m, take(Xs,100), take(Ys,100),
|
Flux.train!(m, Xs, Ys, η = 0.01, epoch = 1)
|
||||||
# η = 0.1, epoch = 1)
|
|
||||||
Flux.train!(m, Xs, Ys, η = 0.2e-3, epoch = 1)
|
|
||||||
|
|
||||||
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
|
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
|
||||||
|
|
||||||
@ -27,7 +26,7 @@ function sample(model, n)
|
|||||||
s = [rand(alphabet)]
|
s = [rand(alphabet)]
|
||||||
m = tf(unroll(model, 1))
|
m = tf(unroll(model, 1))
|
||||||
for i = 1:n
|
for i = 1:n
|
||||||
push!(s, onecold(m(Seq((onehot(Float32, 'b', alphabet),)))[1], alphabet))
|
push!(s, wsample(alphabet, m(Seq((onehot(Float32, s[end], alphabet),)))[1]))
|
||||||
end
|
end
|
||||||
return string(s...)
|
return string(s...)
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user