sampling + tweaks

This commit is contained in:
Mike J Innes 2016-10-30 16:07:29 +00:00
parent 508364407e
commit 4517e41226
2 changed files with 17 additions and 3 deletions

View File

@ -11,11 +11,25 @@ Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
model = Chain(
Input(N),
Recurrent(N, 128, N),
Recurrent(N, 128),
Dense(128, N),
softmax)
m = tf(unroll(model, 50))
# Flux.train!(m, take(Xs,100), take(Ys,100),
# η = 0.1, epoch = 1)
Flux.train!(m, Xs, Ys, η = 0.2e-3, epoch = 1)
map(c->onecold(c, alphabet), m(first(first(first(train)))))
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
function sample(model, n)
s = [rand(alphabet)]
m = tf(unroll(model, 1))
for i = 1:n
push!(s, onecold(m(Seq((onehot(Float32, 'b', alphabet),)))[1], alphabet))
end
return string(s...)
end
sample(model, 100) |> println

View File

@ -2,7 +2,7 @@ export AArray
const AArray = AbstractArray
initn(dims...) = randn(Float32, dims...)/1000
initn(dims...) = randn(Float32, dims...)/10
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
i = 0