organise training and utils

This commit is contained in:
Mike J Innes 2017-05-01 12:41:54 +01:00
parent 0e6bb17709
commit 38852964f6
5 changed files with 43 additions and 44 deletions

View File

@ -12,14 +12,15 @@ using Juno: Tree, Row
# Zero Flux Given
include("dims/utils.jl")
include("utils.jl")
include("dims/catmat.jl")
include("dims/batching.jl")
include("dims/seq.jl")
include("model.jl")
include("utils.jl")
include("data.jl")
include("training.jl")
include("compiler/code.jl")
include("compiler/loops.jl")

View File

@ -1,11 +1,5 @@
export onehot, onecold, chunk, partition, batches, sequences
mapt(f, x) = f(x)
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
"""
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]

View File

@ -1,7 +0,0 @@
export unsqueeze
unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
Base.squeeze(xs) = squeeze(xs, 1)
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]

30
src/training.jl Normal file
View File

@ -0,0 +1,30 @@
tobatch(xs::Batch) = rawbatch(xs)
tobatch(xs) = tobatch(batchone(xs))
function accuracy(m, data)
correct = 0
for (x, y) in data
x, y = tobatch.((x, y))
correct += sum(onecold(m(x)) .== onecold(y))
end
return correct/length(data)
end
function train!(m, train, test = [];
epoch = 1, η = 0.1, loss = mse)
i = 0
for e in 1:epoch
info("Epoch $e")
@progress for (x, y) in train
x, y = tobatch.((x, y))
i += 1
= m(x)
any(isnan, ) && error("NaN")
Δ = back!(loss, 1, , y)
back!(m, Δ, x)
update!(m, η)
i % 1000 == 0 && @show accuracy(m, test)
end
end
return m
end

View File

@ -1,36 +1,17 @@
export AArray
export AArray, unsqueeze
const AArray = AbstractArray
initn(dims...) = randn(dims...)/100
tobatch(xs::Batch) = rawbatch(xs)
tobatch(xs) = tobatch(batchone(xs))
unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
Base.squeeze(xs) = squeeze(xs, 1)
function train!(m, train, test = [];
epoch = 1, η = 0.1, loss = mse)
i = 0
for e in 1:epoch
info("Epoch $e")
@progress for (x, y) in train
x, y = tobatch.((x, y))
i += 1
= m(x)
any(isnan, ) && error("NaN")
Δ = back!(loss, 1, , y)
back!(m, Δ, x)
update!(m, η)
i % 1000 == 0 && @show accuracy(m, test)
end
end
return m
end
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
function accuracy(m, data)
correct = 0
for (x, y) in data
x, y = tobatch.((x, y))
correct += sum(onecold(m(x)) .== onecold(y))
end
return correct/length(data)
end
mapt(f, x) = f(x)
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs