magic constants
This commit is contained in:
parent
752c4e39e8
commit
d3e0f455c2
@ -1,8 +1,11 @@
|
|||||||
using Flux
|
using Flux
|
||||||
import StatsBase: wsample
|
import StatsBase: wsample
|
||||||
|
|
||||||
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50)
|
nunroll = 50
|
||||||
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...)
|
nbatch = 50
|
||||||
|
|
||||||
|
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
|
||||||
|
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
|
||||||
|
|
||||||
input = readstring("$(homedir())/Downloads/shakespeare_input.txt")
|
input = readstring("$(homedir())/Downloads/shakespeare_input.txt")
|
||||||
alphabet = unique(input)
|
alphabet = unique(input)
|
||||||
@ -17,7 +20,7 @@ model = Chain(
|
|||||||
Affine(256, N),
|
Affine(256, N),
|
||||||
softmax)
|
softmax)
|
||||||
|
|
||||||
m = tf(unroll(model, 50));
|
m = tf(unroll(model, nunroll))
|
||||||
|
|
||||||
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user