batched training for char-rnn
parent f0779fc77c · commit 73ff5b4201
@@ -1,20 +1,22 @@
 using Flux
 using Juno

+getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50)
+getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...)

 input = readstring("$(homedir())/Downloads/shakespeare_input.txt")
-const alphabet = unique(input)
+alphabet = unique(input)

-getseqs(data, n) = (Seq(onehot(Float32, char, alphabet) for char in chunk) for chunk in chunks(data, n))

-data = zip(getseqs(input, 50), getseqs(input[2:end], 50))
+train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))

 model = Chain(
   Input(length(alphabet)),
   Flux.Recurrent(length(alphabet), 128, length(alphabet)),
   softmax)

-unrolled = unroll(model, 50)
+m = tf(unroll(model, 50))
-m = tf(unrolled)
+Flux.train!(m, train, η = 0.1/50, epoch = 5)
-Flux.train!(m, data)
+map(c->onecold(c, alphabet), m(train[1][1][1]))
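The new helpers replace the old single-stream `getseqs` with a batched pipeline: the text is split into 50 parallel chunks, each chunk is one-hot encoded and cut into length-50 subsequences, and `batches` zips the i-th subsequence of every chunk into one training batch (the targets are the same stream shifted by one character). A minimal sketch of the shapes involved, using plain arrays in place of Flux's `Seq`/`Batch` wrappers and a toy input (all names below are illustrative, not part of the commit):

# Toy input standing in for the Shakespeare text:
text  = collect("the quick brown fox jumps over")
alpha = unique(text)

# One-hot encode a character against the alphabet, as onehot(Float32, ...) does:
encode(c) = Float32[a == c for a in alpha]

# Split into 3 contiguous chunks (cf. chunk(chars, 50)), then cut each chunk
# into length-5 subsequences (cf. sequences(..., 50)):
n = length(text) ÷ 3
parts = [text[(i-1)*n+1 : i*n] for i in 1:3]
seqs  = [[map(encode, p[j:j+4]) for j in 1:5:n-4] for p in parts]

# One "batch" pairs up the i-th subsequence of every chunk, so each training
# step feeds 3 independent streams through the unrolled RNN:
batch1 = collect(zip((s[1] for s in seqs)...))

(The learning rate `η = 0.1/50` is presumably scaled down by the unroll length, since the squared-error loss is summed over all 50 timesteps of the unrolled network.)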
@@ -10,6 +10,7 @@ import Juno: Tree, Row

 include("model.jl")
 include("utils.jl")
+include("data.jl")

 include("compiler/graph.jl")
 include("compiler/diff.jl")
@@ -44,7 +44,6 @@ end
 function Flux.train!(m::SeqModel, train; epoch = 1, η = 0.1,
                      loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
                      opt = TensorFlow.train.GradientDescentOptimizer(η))
-  i = 0
   Y = placeholder(Float32)
   Loss = loss(m.m.output[end], Y)
   minimize_op = TensorFlow.train.minimize(opt, Loss)
@@ -52,12 +51,8 @@ function Flux.train!(m::SeqModel, train; epoch = 1, η = 0.1,
    info("Epoch $e\n")
    @progress for (x, y) in train
      y, cur_loss, _ = run(m.m.session, vcat(m.m.output[end], Loss, minimize_op),
-                          merge(Dict(m.m.inputs[end]=>batchone(x), Y=>batchone(y)),
+                          merge(Dict(m.m.inputs[end]=>x, Y=>y),
                                 Dict(zip(m.m.inputs[1:end-1], m.state))))
-      if i % 5000 == 0
-        @show y
-      end
-      i += 1
    end
  end
end
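Two changes to the sequence-model trainer: the `batchone` wrapping is gone because `train` now delivers ready-made batches (see `getbatches` above), and the `i` counter with its periodic `@show y` debug output is removed (note that `y` is rebound to the network output by the `run(...)` call, so the old debug print showed predictions, not targets). The default loss is half the summed squared error; a plain-array equivalent for intuition (the helper name below is illustrative):

# Half summed squared error, mirroring (y, y′) -> reduce_sum((y - y′).^2)/2:
sse_half(y, y′) = sum((y .- y′).^2) / 2

sse_half([1.0, 0.0, 0.0], [0.8, 0.1, 0.1])  # (0.04 + 0.01 + 0.01)/2 = 0.03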
@@ -61,6 +61,8 @@ function build_backward(body, x, params = [])
   syntax(cse(back))
 end

+import Lazy: groupby
+
 function process_type(ex)
   @capture(ex, type T_ fs__ end)
   @destruct [params = false || [],
src/data.jl (new file, 18 lines)
@@ -0,0 +1,18 @@
+export onehot, onecold, chunk, partition, batches, sequences
+
+onehot(T::Type, label, labels) = T[i == label for i in labels]
+onehot(label, labels) = onehot(Int, label, labels)
+onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
+
+using Iterators
+import Iterators: partition
+
+export partition
+
+_partition(r::UnitRange, step::Integer) = (step*(i-1)+1:step*i for i in 1:(r.stop÷step))
+_partition(xs, step) = (xs[i] for i in _partition(1:length(xs), step))
+
+chunk(xs, n) = _partition(xs, length(xs)÷n)
+
+batches(xs...) = (Batch(x) for x in zip(xs...))
+sequences(xs, len) = (Seq(x) for x in partition(xs, len))
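src/data.jl collects the data-handling helpers in one place. Note the distinction: `chunk(xs, n)` splits `xs` into `n` equal-length pieces (dropping any remainder), while `partition(xs, len)` from Iterators.jl yields pieces of length `len`; `batches` and `sequences` wrap the results in Flux's `Batch` and `Seq` containers. Expected behavior on toy inputs, per the definitions above:

onehot('b', ['a', 'b', 'c'])               # => [0, 1, 0]
onecold([0.1, 0.7, 0.2], ['a', 'b', 'c'])  # => 'b'

xs = collect(1:12)
collect(chunk(xs, 3))      # 3 pieces of length 4: [1,2,3,4], [5,6,7,8], [9,10,11,12]
collect(partition(xs, 3))  # length-3 pieces: (1,2,3), (4,5,6), (7,8,9), (10,11,12)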
@@ -1,14 +1,7 @@
-export AArray, onehot, onecold, chunks
+export AArray

 const AArray = AbstractArray

-onehot(T::Type, label, labels) = T[i == label for i in labels]
-onehot(label, labels) = onehot(Int, label, labels)
-onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
-
-chunks(r::UnitRange, step::Integer) = (step*(i-1)+1:step*i for i in 1:(r.stop÷step))
-chunks(xs, step) = (xs[i] for i in chunks(1:length(xs), step))
-
 initn(dims...) = randn(Float32, dims...)/1000

 function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)
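With the move, this file (presumably src/utils.jl) keeps only the generic utilities: the `AArray` alias, the small-weight initializer, and the generic `train!`. The `onehot`/`onecold` and chunking helpers now live in data.jl, with `chunks` generalized into `_partition`/`chunk`. For reference, `initn` draws small Gaussian weights:

W = initn(128, 64)  # 128×64 Float32 matrix, entries drawn from N(0, 1)/1000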