sampling + tweaks
This commit is contained in:
parent
508364407e
commit
4517e41226
@ -11,11 +11,25 @@ Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
|||||||
|
|
||||||
model = Chain(
|
model = Chain(
|
||||||
Input(N),
|
Input(N),
|
||||||
Recurrent(N, 128, N),
|
Recurrent(N, 128),
|
||||||
|
Dense(128, N),
|
||||||
softmax)
|
softmax)
|
||||||
|
|
||||||
m = tf(unroll(model, 50))
|
m = tf(unroll(model, 50))
|
||||||
|
|
||||||
|
# Flux.train!(m, take(Xs,100), take(Ys,100),
|
||||||
|
# η = 0.1, epoch = 1)
|
||||||
Flux.train!(m, Xs, Ys, η = 0.2e-3, epoch = 1)
|
Flux.train!(m, Xs, Ys, η = 0.2e-3, epoch = 1)
|
||||||
|
|
||||||
map(c->onecold(c, alphabet), m(first(first(first(train)))))
|
string(map(c -> onecold(c, alphabet), m(first(first(Xs))))...)
|
||||||
|
|
||||||
|
function sample(model, n)
|
||||||
|
s = [rand(alphabet)]
|
||||||
|
m = tf(unroll(model, 1))
|
||||||
|
for i = 1:n
|
||||||
|
push!(s, onecold(m(Seq((onehot(Float32, 'b', alphabet),)))[1], alphabet))
|
||||||
|
end
|
||||||
|
return string(s...)
|
||||||
|
end
|
||||||
|
|
||||||
|
sample(model, 100) |> println
|
||||||
|
@ -2,7 +2,7 @@ export AArray
|
|||||||
|
|
||||||
const AArray = AbstractArray
|
const AArray = AbstractArray
|
||||||
|
|
||||||
initn(dims...) = randn(Float32, dims...)/1000
|
initn(dims...) = randn(Float32, dims...)/10
|
||||||
|
|
||||||
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
|
||||||
i = 0
|
i = 0
|
||||||
|
Loading…
Reference in New Issue
Block a user