char-rnn example
This commit is contained in:
parent
3e35cdc462
commit
209c089e5a
@ -16,7 +16,8 @@ makedocs(modules=[Flux],
|
|||||||
"Batching" => "apis/batching.md",
|
"Batching" => "apis/batching.md",
|
||||||
"Backends" => "apis/backends.md"],
|
"Backends" => "apis/backends.md"],
|
||||||
"In Action" => [
|
"In Action" => [
|
||||||
"Logistic Regression" => "examples/logreg.md"],
|
"Logistic Regression" => "examples/logreg.md"
|
||||||
|
"Char RNN" => "examples/char-rnn.md"],
|
||||||
"Contributing & Help" => "contributing.md",
|
"Contributing & Help" => "contributing.md",
|
||||||
"Internals" => "internals.md"])
|
"Internals" => "internals.md"])
|
||||||
|
|
||||||
|
40
docs/src/examples/char-rnn.md
Normal file
40
docs/src/examples/char-rnn.md
Normal file
@ -0,0 +1,40 @@
|
|||||||
|
# Char RNN
|
||||||
|
|
||||||
|
```julia
|
||||||
|
using Flux
|
||||||
|
import StatsBase: wsample
|
||||||
|
|
||||||
|
nunroll = 50
|
||||||
|
nbatch = 50
|
||||||
|
|
||||||
|
getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), nunroll)
|
||||||
|
getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, nbatch))...)
|
||||||
|
|
||||||
|
input = readstring("$(homedir())/Downloads/shakespeare_input.txt")
|
||||||
|
alphabet = unique(input)
|
||||||
|
N = length(alphabet)
|
||||||
|
|
||||||
|
Xs, Ys = getbatches(input, alphabet), getbatches(input[2:end], alphabet)
|
||||||
|
|
||||||
|
model = Chain(
|
||||||
|
Input(N),
|
||||||
|
LSTM(N, 256),
|
||||||
|
LSTM(256, 256),
|
||||||
|
Affine(256, N),
|
||||||
|
softmax)
|
||||||
|
|
||||||
|
m = tf(unroll(model, nunroll))
|
||||||
|
|
||||||
|
@time Flux.train!(m, Xs, Ys, η = 0.1, epoch = 1)
|
||||||
|
|
||||||
|
function sample(model, n, temp = 1)
|
||||||
|
s = [rand(alphabet)]
|
||||||
|
m = tf(unroll(model, 1))
|
||||||
|
for i = 1:n
|
||||||
|
push!(s, wsample(alphabet, softmax(m(Seq((onehot(Float32, s[end], alphabet),)))[1]./temp)))
|
||||||
|
end
|
||||||
|
return string(s...)
|
||||||
|
end
|
||||||
|
|
||||||
|
sample(model, 100)
|
||||||
|
```
|
Loading…
Reference in New Issue
Block a user