organise training and utils
This commit is contained in:
parent
0e6bb17709
commit
38852964f6
@ -12,14 +12,15 @@ using Juno: Tree, Row
|
|||||||
|
|
||||||
# Zero Flux Given
|
# Zero Flux Given
|
||||||
|
|
||||||
include("dims/utils.jl")
|
include("utils.jl")
|
||||||
|
|
||||||
include("dims/catmat.jl")
|
include("dims/catmat.jl")
|
||||||
include("dims/batching.jl")
|
include("dims/batching.jl")
|
||||||
include("dims/seq.jl")
|
include("dims/seq.jl")
|
||||||
|
|
||||||
include("model.jl")
|
include("model.jl")
|
||||||
include("utils.jl")
|
|
||||||
include("data.jl")
|
include("data.jl")
|
||||||
|
include("training.jl")
|
||||||
|
|
||||||
include("compiler/code.jl")
|
include("compiler/code.jl")
|
||||||
include("compiler/loops.jl")
|
include("compiler/loops.jl")
|
||||||
|
@ -1,11 +1,5 @@
|
|||||||
export onehot, onecold, chunk, partition, batches, sequences
|
export onehot, onecold, chunk, partition, batches, sequences
|
||||||
|
|
||||||
mapt(f, x) = f(x)
|
|
||||||
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
|
|
||||||
|
|
||||||
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
|
|
||||||
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
|
onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
|
||||||
|
|
||||||
|
@ -1,7 +0,0 @@
|
|||||||
export unsqueeze
|
|
||||||
|
|
||||||
unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
|
||||||
Base.squeeze(xs) = squeeze(xs, 1)
|
|
||||||
|
|
||||||
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
|
|
||||||
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
|
30
src/training.jl
Normal file
30
src/training.jl
Normal file
@ -0,0 +1,30 @@
|
|||||||
|
tobatch(xs::Batch) = rawbatch(xs)
|
||||||
|
tobatch(xs) = tobatch(batchone(xs))
|
||||||
|
|
||||||
|
function accuracy(m, data)
|
||||||
|
correct = 0
|
||||||
|
for (x, y) in data
|
||||||
|
x, y = tobatch.((x, y))
|
||||||
|
correct += sum(onecold(m(x)) .== onecold(y))
|
||||||
|
end
|
||||||
|
return correct/length(data)
|
||||||
|
end
|
||||||
|
|
||||||
|
function train!(m, train, test = [];
|
||||||
|
epoch = 1, η = 0.1, loss = mse)
|
||||||
|
i = 0
|
||||||
|
for e in 1:epoch
|
||||||
|
info("Epoch $e")
|
||||||
|
@progress for (x, y) in train
|
||||||
|
x, y = tobatch.((x, y))
|
||||||
|
i += 1
|
||||||
|
ŷ = m(x)
|
||||||
|
any(isnan, ŷ) && error("NaN")
|
||||||
|
Δ = back!(loss, 1, ŷ, y)
|
||||||
|
back!(m, Δ, x)
|
||||||
|
update!(m, η)
|
||||||
|
i % 1000 == 0 && @show accuracy(m, test)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
return m
|
||||||
|
end
|
39
src/utils.jl
39
src/utils.jl
@ -1,36 +1,17 @@
|
|||||||
export AArray
|
export AArray, unsqueeze
|
||||||
|
|
||||||
const AArray = AbstractArray
|
const AArray = AbstractArray
|
||||||
|
|
||||||
initn(dims...) = randn(dims...)/100
|
initn(dims...) = randn(dims...)/100
|
||||||
|
|
||||||
tobatch(xs::Batch) = rawbatch(xs)
|
unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
|
||||||
tobatch(xs) = tobatch(batchone(xs))
|
Base.squeeze(xs) = squeeze(xs, 1)
|
||||||
|
|
||||||
function train!(m, train, test = [];
|
stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
|
||||||
epoch = 1, η = 0.1, loss = mse)
|
unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
|
||||||
i = 0
|
|
||||||
for e in 1:epoch
|
|
||||||
info("Epoch $e")
|
|
||||||
@progress for (x, y) in train
|
|
||||||
x, y = tobatch.((x, y))
|
|
||||||
i += 1
|
|
||||||
ŷ = m(x)
|
|
||||||
any(isnan, ŷ) && error("NaN")
|
|
||||||
Δ = back!(loss, 1, ŷ, y)
|
|
||||||
back!(m, Δ, x)
|
|
||||||
update!(m, η)
|
|
||||||
i % 1000 == 0 && @show accuracy(m, test)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
return m
|
|
||||||
end
|
|
||||||
|
|
||||||
function accuracy(m, data)
|
mapt(f, x) = f(x)
|
||||||
correct = 0
|
mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
|
||||||
for (x, y) in data
|
|
||||||
x, y = tobatch.((x, y))
|
convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
|
||||||
correct += sum(onecold(m(x)) .== onecold(y))
|
convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
|
||||||
end
|
|
||||||
return correct/length(data)
|
|
||||||
end
|
|
||||||
|
Loading…
Reference in New Issue
Block a user