diff --git a/examples/char-rnn.jl b/examples/char-rnn.jl
index 5898d31e..78c04033 100644
--- a/examples/char-rnn.jl
+++ b/examples/char-rnn.jl
@@ -4,14 +4,16 @@ import StatsBase: wsample
 nunroll = 50
 nbatch = 50
 
-getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
-getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
+getseqs(chars, alphabet) =
+  sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
+getbatches(chars, alphabet) =
+  batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
 
 input = readstring("$(homedir())/Downloads/shakespeare_input.txt");
 alphabet = unique(input)
 N = length(alphabet)
 
-Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
+train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
 
 model = Chain(
   Input(N),
@@ -20,17 +22,17 @@ model = Chain(
   Affine(256, N),
   softmax)
 
-m = tf(unroll(model, nunroll))
+m = mxnet(unroll(model, nunroll))
 
-@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
+@time Flux.train!(m, train, η = 1, epoch = 1)
 
 function sample(model, n, temp = 1)
   s = [rand(alphabet)]
-  m = tf(unroll(model, 1))
-  for i = 1:n
-    push!(s, wsample(alphabet, softmax(m(Seq((onehot(Float32, s[end], alphabet),)))[1]./temp)))
+  m = mxnet(unroll1(model))
+  for i = 1:n-1
+    push!(s, wsample(alphabet, softmax(m(unsqueeze(onehot(s[end], alphabet)))./temp)[1,:]))
   end
   return string(s...)
 end
 
-sample(model[1:end-1], 100)
+s = sample(model[1:end-1], 100)