batched training for char-rnn

Mike J Innes 2016-10-29 23:36:39 +01:00
parent f0779fc77c
commit 73ff5b4201
7 changed files with 34 additions and 22 deletions

REQUIRE
@@ -1,2 +1,3 @@
 julia 0.5-
 TensorFlow
+Iterators

@@ -1,20 +1,22 @@
 using Flux
 using Juno
+getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50)
+getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...)
 
 input = readstring("$(homedir())/Downloads/shakespeare_input.txt")
-const alphabet = unique(input)
+alphabet = unique(input)
 
-getseqs(data, n) = (Seq(onehot(Float32, char, alphabet) for char in chunk) for chunk in chunks(data, n))
-data = zip(getseqs(input, 50), getseqs(input[2:end], 50))
+train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
 
 model = Chain(
   Input(length(alphabet)),
   Flux.Recurrent(length(alphabet), 128, length(alphabet)),
   softmax)
 
-m = tf(unroll(model, 50))
+unrolled = unroll(model, 50)
+m = tf(unrolled)
 
-Flux.train!(m, data)
+Flux.train!(m, train, η = 0.1/50, epoch = 5)
 
 map(c->onecold(c, alphabet), m(train[1][1][1]))
 
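
To make the new data flow concrete: the text is cut into 50 parallel streams, each stream into length-50 sequences of one-hot vectors, and a batch pairs the i-th sequence of every stream; targets are the same batches built from the text shifted one character ahead. Below is a rough plain-Julia sketch under those assumptions, with Flux's Seq/Batch wrappers elided and all names and sizes illustrative (5 streams and length-3 sequences instead of 50 and 50):

    # fixed-length pieces (like Iterators.partition)
    partition(xs, len) = [xs[i:i+len-1] for i in 1:len:length(xs)-len+1]
    # n roughly equal pieces (like chunk in src/data.jl)
    chunk(xs, n) = partition(xs, length(xs) ÷ n)
    onehot(T, l, ls) = T[x == l for x in ls]

    text     = collect("a toy corpus standing in for the shakespeare text")
    alphabet = unique(text)

    # each parallel stream becomes a list of one-hot sequences
    getseqs(chars, len) = [[onehot(Float32, c, alphabet) for c in seq]
                           for seq in partition(chars, len)]

    # a batch pairs the i-th sequence of every stream, as batches/zip do
    batchstream(chars) = zip((getseqs(s, 3) for s in chunk(chars, 5))...)

    # targets: the same batches built from text shifted one character ahead
    train = collect(zip(batchstream(text), batchstream(text[2:end])))

The η = 0.1/50 passed to train! plausibly compensates for the loss being summed, rather than averaged, over the unrolled batch, keeping the effective step size close to the unbatched version's.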

@@ -10,6 +10,7 @@ import Juno: Tree, Row
include("model.jl")
include("utils.jl")
include("data.jl")
include("compiler/graph.jl")
include("compiler/diff.jl")

@@ -44,7 +44,6 @@ end
 function Flux.train!(m::SeqModel, train; epoch = 1, η = 0.1,
                      loss = (y, ŷ) -> reduce_sum((y - ŷ).^2)/2,
                      opt = TensorFlow.train.GradientDescentOptimizer(η))
-  i = 0
   Y = placeholder(Float32)
   Loss = loss(m.m.output[end], Y)
   minimize_op = TensorFlow.train.minimize(opt, Loss)
@@ -52,12 +51,8 @@ function Flux.train!(m::SeqModel, train; epoch = 1, η = 0.1,
     info("Epoch $e\n")
     @progress for (x, y) in train
       y, cur_loss, _ = run(m.m.session, vcat(m.m.output[end], Loss, minimize_op),
-                           merge(Dict(m.m.inputs[end]=>batchone(x), Y=>batchone(y)),
+                           merge(Dict(m.m.inputs[end]=>x, Y=>y),
                                  Dict(zip(m.m.inputs[1:end-1], m.state))))
-      if i % 5000 == 0
-        @show y
-      end
-      i += 1
     end
   end
 end
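
The dropped batchone calls are the point of this change: x and y now arrive pre-batched from the data pipeline, so they are fed to the graph directly. The merge builds one feed dictionary covering both the data placeholders and the carried RNN state. A plain-Julia sketch of that construction, with symbols and sizes standing in for the TensorFlow placeholders and tensors:

    # illustrative stand-ins for m.m.inputs, m.state, and one batch
    inputs = [:h1, :h2, :x]            # state placeholders, then the data input
    state  = [zeros(1, 128), zeros(1, 128)]
    x, y   = rand(50, 26), rand(50, 26)   # batch of one-hot rows, and targets

    feed = merge(Dict(inputs[end] => x, :Y => y),     # data and targets
                 Dict(zip(inputs[1:end-1], state)))   # carried hidden state
    # feed == Dict(:x => x, :Y => y, :h1 => state[1], :h2 => state[2])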

@@ -61,6 +61,8 @@ function build_backward(body, x, params = [])
   syntax(cse(back))
 end
 
+import Lazy: groupby
+
 function process_type(ex)
   @capture(ex, type T_ fs__ end)
   @destruct [params = false || [],

src/data.jl (new file, 18 lines)
@@ -0,0 +1,18 @@
+export onehot, onecold, chunk, partition, batches, sequences
+
+onehot(T::Type, label, labels) = T[i == label for i in labels]
+onehot(label, labels) = onehot(Int, label, labels)
+onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
+
+using Iterators
+import Iterators: partition
+export partition
+
+_partition(r::UnitRange, step::Integer) = (step*(i-1)+1:step*i for i in 1:(r.stop÷step))
+_partition(xs, step) = (xs[i] for i in _partition(1:length(xs), step))
+
+chunk(xs, n) = _partition(xs, length(xs)÷n)
+
+batches(xs...) = (Batch(x) for x in zip(xs...))
+
+sequences(xs, len) = (Seq(x) for x in partition(xs, len))
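
A quick usage sketch for these helpers, in the commit's Julia 0.5 setting and with the Batch/Seq wrappers out of the picture (the comments show what the definitions above yield):

    chars = collect("abcdefghijkl")

    collect(chunk(chars, 3))      # 3 pieces: a–d, e–h, i–l
    collect(partition(chars, 4))  # the same pieces, but by fixed length 4

    onehot(Float32, 'b', ['a', 'b', 'c'])            # Float32[0.0, 1.0, 0.0]
    onecold([0.1f0, 0.9f0, 0.0f0], ['a', 'b', 'c'])  # 'b'

So chunk fixes the piece count (the number of parallel streams) while partition fixes the piece length (the sequence length), which is exactly how getbatches and getseqs divide the corpus in the char-rnn example.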

@@ -1,14 +1,7 @@
-export AArray, onehot, onecold, chunks
+export AArray
 const AArray = AbstractArray
 
-onehot(T::Type, label, labels) = T[i == label for i in labels]
-onehot(label, labels) = onehot(Int, label, labels)
-onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
-
-chunks(r::UnitRange, step::Integer) = (step*(i-1)+1:step*i for i in 1:(r.stop÷step))
-chunks(xs, step) = (xs[i] for i in chunks(1:length(xs), step))
-
 initn(dims...) = randn(Float32, dims...)/1000
 
 function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)