batched training for char-rnn

commit 73ff5b4201 (parent f0779fc77c)
@@ -1,20 +1,22 @@
 using Flux
+using Juno
 
-getseqs(data, n) = (Seq(onehot(Float32, char, alphabet) for char in chunk) for chunk in chunks(data, n))
+getseqs(chars, alphabet) = sequences((onehot(Float32, char, alphabet) for char in chars), 50)
+getbatches(chars, alphabet) = batches((getseqs(part, alphabet) for part in chunk(chars, 50))...)
 
 input = readstring("$(homedir())/Downloads/shakespeare_input.txt")
 
-alphabet = unique(input)
+const alphabet = unique(input)
 
-data = zip(getseqs(input, 50), getseqs(input[2:end], 50))
+train = zip(getbatches(input, alphabet), getbatches(input[2:end], alphabet))
 
 model = Chain(
   Input(length(alphabet)),
   Flux.Recurrent(length(alphabet), 128, length(alphabet)),
   softmax)
 
-unrolled = unroll(model, 50)
-
-m = tf(unrolled)
+m = tf(unroll(model, 50))
 
-Flux.train!(m, data)
+Flux.train!(m, train, η = 0.1/50, epoch = 5)
+
+map(c->onecold(c, alphabet), m(train[1][1][1]))
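For intuition, the new pipeline splits the corpus into 50 contiguous parts, cuts each part into length-50 sequences of one-hot vectors, and zips corresponding sequences across parts into batches, so each training step sees 50 sequences at once. A minimal plain-Julia sketch of that scheme follows; the corpus and every helper name in it are illustrative stand-ins, not the commit's API.

    # Sketch only: mirrors chunk/getseqs/batches above with hypothetical names.
    corpus = repeat("the quick brown fox jumps over the lazy dog ", 200)
    alpha  = unique(corpus)

    onehot_vec(c) = Float32[a == c for a in alpha]   # one-hot against the alphabet

    step  = length(corpus) ÷ 50                      # mirrors chunk(chars, 50)
    parts = [corpus[(i-1)*step+1 : i*step] for i in 1:50]

    # each part becomes length-50 sequences of one-hot vectors (mirrors getseqs)
    seqs(part) = [map(onehot_vec, s) for s in Iterators.partition(collect(part), 50)]

    # corresponding sequences across the 50 parts form one batch (mirrors batches)
    batchstream = zip(map(seqs, parts)...)
    first(batchstream)   # a 50-tuple of sequences, each a list of one-hot vectors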
@@ -10,6 +10,7 @@ import Juno: Tree, Row
 
 include("model.jl")
 include("utils.jl")
+include("data.jl")
 
 include("compiler/graph.jl")
 include("compiler/diff.jl")
@@ -44,7 +44,6 @@
 function Flux.train!(m::SeqModel, train; epoch = 1, η = 0.1,
                      loss = (y, y′) -> reduce_sum((y - y′).^2)/2,
                      opt = TensorFlow.train.GradientDescentOptimizer(η))
-  i = 0
   Y = placeholder(Float32)
   Loss = loss(m.m.output[end], Y)
   minimize_op = TensorFlow.train.minimize(opt, Loss)
@@ -52,12 +51,8 @@ function Flux.train!(m::SeqModel, train; epoch = 1, η = 0.1,
     info("Epoch $e\n")
     @progress for (x, y) in train
       y, cur_loss, _ = run(m.m.session, vcat(m.m.output[end], Loss, minimize_op),
-                           merge(Dict(m.m.inputs[end]=>batchone(x), Y=>batchone(y)),
+                           merge(Dict(m.m.inputs[end]=>x, Y=>y),
                                  Dict(zip(m.m.inputs[1:end-1], m.state))))
-      if i % 5000 == 0
-        @show y
-      end
-      i += 1
     end
   end
 end
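The feed now passes whole batches (x, y) straight to the session where it previously wrapped single examples with batchone, and the char-rnn example above scales the learning rate to η = 0.1/50, presumably to offset the loss being summed over the 50 unrolled steps. A plain-Julia sketch of that loop shape, with a linear model standing in for the TensorFlow graph (all names here are illustrative):

    W = randn(Float32, 4, 4) ./ 100              # stand-in parameters
    loss(y, y′) = sum((y .- y′).^2) / 2          # same default loss as above
    η = 0.1f0 / 50                               # scaled as in the example

    toy_batches = [(randn(Float32, 4, 50), randn(Float32, 4, 50)) for _ in 1:10]
    for e in 1:5
        for (x, y′) in toy_batches               # x carries a whole batch of columns
            y = W * x
            g = (y .- y′) * x'                   # gradient of the summed squared error in W
            global W -= η .* g
        end
    end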
@@ -61,6 +61,8 @@ function build_backward(body, x, params = [])
   syntax(cse(back))
 end
 
+import Lazy: groupby
+
 function process_type(ex)
   @capture(ex, type T_ fs__ end)
   @destruct [params = false || [],
src/data.jl (new file, 18 lines)
@@ -0,0 +1,18 @@
+export onehot, onecold, chunk, partition, batches, sequences
+
+onehot(T::Type, label, labels) = T[i == label for i in labels]
+onehot(label, labels) = onehot(Int, label, labels)
+onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
+
+using Iterators
+import Iterators: partition
+
+export partition
+
+_partition(r::UnitRange, step::Integer) = (step*(i-1)+1:step*i for i in 1:(r.stop÷step))
+_partition(xs, step) = (xs[i] for i in _partition(1:length(xs), step))
+
+chunk(xs, n) = _partition(xs, length(xs)÷n)
+
+batches(xs...) = (Batch(x) for x in zip(xs...))
+sequences(xs, len) = (Seq(x) for x in partition(xs, len))
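Note the chunking helpers silently drop any trailing remainder, and onecold is the inverse lookup of onehot (the file's findfirst(pred, maximum(pred)) is 0.5-era Julia; today argmax does the same job). A quick demonstration using renamed copies of the definitions above:

    _part(r::UnitRange, step::Integer) = (step*(i-1)+1:step*i for i in 1:(r.stop ÷ step))
    _part(xs, step) = (xs[i] for i in _part(1:length(xs), step))
    chunk_demo(xs, n) = _part(xs, length(xs) ÷ n)

    collect(chunk_demo(collect(1:10), 3))   # [[1,2,3],[4,5,6],[7,8,9]]; the 10 is dropped

    labels = ['a', 'b', 'c']
    v = Float32[l == 'b' for l in labels]   # onehot(Float32, 'b', labels) == [0, 1, 0]
    labels[argmax(v)]                       # 'b'; what onecold recovers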
@@ -1,14 +1,7 @@
-export AArray, onehot, onecold, chunks
+export AArray
 
 const AArray = AbstractArray
 
-onehot(T::Type, label, labels) = T[i == label for i in labels]
-onehot(label, labels) = onehot(Int, label, labels)
-onecold(pred, labels = 1:length(pred)) = labels[findfirst(pred, maximum(pred))]
-
-chunks(r::UnitRange, step::Integer) = (step*(i-1)+1:step*i for i in 1:(r.stop÷step))
-chunks(xs, step) = (xs[i] for i in chunks(1:length(xs), step))
-
 initn(dims...) = randn(Float32, dims...)/1000
 
 function train!(m, train, test = []; epoch = 1, batch = 10, η = 0.1)