magic constants

This commit is contained in:
Mike J Innes 2016-12-15 21:07:07 +00:00
parent 752c4e39e8
commit d3e0f455c2
1 changed files with 6 additions and 3 deletions

View File

@ -1,8 +1,11 @@
using Flux
import StatsBase: wsample
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50)
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...)
nunroll = 50
nbatch = 50
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
input = readstring("$(homedir())/Downloads/shakespeare_input.txt")
alphabet = unique(input)
@ -17,7 +20,7 @@ model = Chain(
Affine(256, N),
softmax)
m = tf(unroll(model, 50));
m = tf(unroll(model, nunroll))
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)