sampling + tweaks
This commit is contained in:
parent
508364407e
commit
4517e41226
@ -11,11 +11,25 @@ Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
||||
|
||||
model = Chain(
|
||||
Input(N),
|
||||
Recurrent(N, 128, N),
|
||||
Recurrent(N, 128),
|
||||
Dense(128, N),
|
||||
softmax)
|
||||
|
||||
m = tf(unroll(model, 50))
|
||||
|
||||
# Flux.train!(m, take(Xs,100), take(Ys,100),
|
||||
# η = 0.1, epoch = 1)
|
||||
Flux.train!(m, Xs, Ys, η = 0.2e-3, epoch = 1)
|
||||
|
||||
map(c->onecold(c, alphabet), m(first(first(first(train)))))
|
||||
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
|
||||
|
||||
function sample(model, n)
|
||||
s = [rand(alphabet)]
|
||||
m = tf(unroll(model, 1))
|
||||
for i = 1:n
|
||||
push!(s, onecold(m(Seq((onehot(Float32, 'b', alphabet),)))[1], alphabet))
|
||||
end
|
||||
return string(s...)
|
||||
end
|
||||
|
||||
sample(model, 100) |> println
|
||||
|
@ -2,7 +2,7 @@ export AArray
|
||||
|
||||
const AArray = AbstractArray
|
||||
|
||||
initn(dims...) = randn(Float32, dims...)/1000
|
||||
initn(dims...) = randn(Float32, dims...)/10
|
||||
|
||||
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||
i = 0
|
||||
|
Loading…
Reference in New Issue
Block a user