diff --git a/src/Flux.jl b/src/Flux.jl
index ff20adba..9523cdba 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -12,14 +12,15 @@ using Juno: Tree, Row
 
 # Zero Flux Given
 
-include("dims/utils.jl")
+include("utils.jl")
+
 include("dims/catmat.jl")
 include("dims/batching.jl")
 include("dims/seq.jl")
 
 include("model.jl")
-include("utils.jl")
 include("data.jl")
+include("training.jl")
 
 include("compiler/code.jl")
 include("compiler/loops.jl")
diff --git a/src/data.jl b/src/data.jl
index 63110cc9..48df0c10 100644
--- a/src/data.jl
+++ b/src/data.jl
@@ -1,11 +1,5 @@
 export onehot, onecold, chunk, partition, batches, sequences
 
-mapt(f, x) = f(x)
-mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
-
-convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
-convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs
-
 """
     onehot('b', ['a', 'b', 'c', 'd']) => [false, true, false, false]
 
diff --git a/src/dims/utils.jl b/src/dims/utils.jl
deleted file mode 100644
index 4d87a35a..00000000
--- a/src/dims/utils.jl
+++ /dev/null
@@ -1,7 +0,0 @@
-export unsqueeze
-
-unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
-Base.squeeze(xs) = squeeze(xs, 1)
-
-stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
-unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
diff --git a/src/training.jl b/src/training.jl
new file mode 100644
index 00000000..4ead5076
--- /dev/null
+++ b/src/training.jl
@@ -0,0 +1,30 @@
+tobatch(xs::Batch) = rawbatch(xs)
+tobatch(xs) = tobatch(batchone(xs))
+
+function accuracy(m, data)
+  correct = 0
+  for (x, y) in data
+    x, y = tobatch.((x, y))
+    correct += sum(onecold(m(x)) .== onecold(y))
+  end
+  return correct/length(data)
+end
+
+function train!(m, train, test = [];
+                epoch = 1, η = 0.1, loss = mse)
+  i = 0
+  for e in 1:epoch
+    info("Epoch $e")
+    @progress for (x, y) in train
+      x, y = tobatch.((x, y))
+      i += 1
+      ŷ = m(x)
+      any(isnan, ŷ) && error("NaN")
+      Δ = back!(loss, 1, ŷ, y)
+      back!(m, Δ, x)
+      update!(m, η)
+      i % 1000 == 0 && @show accuracy(m, test)
+    end
+  end
+  return m
+end
diff --git a/src/utils.jl b/src/utils.jl
index c61e2da2..48af20a7 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -1,36 +1,17 @@
-export AArray
+export AArray, unsqueeze
 
 const AArray = AbstractArray
 
 initn(dims...) = randn(dims...)/100
 
-tobatch(xs::Batch) = rawbatch(xs)
-tobatch(xs) = tobatch(batchone(xs))
+unsqueeze(xs, dim = 1) = reshape(xs, (size(xs)[1:dim-1]..., 1, size(xs)[dim:end]...))
+Base.squeeze(xs) = squeeze(xs, 1)
 
-function train!(m, train, test = [];
-                epoch = 1, η = 0.1, loss = mse)
-  i = 0
-  for e in 1:epoch
-    info("Epoch $e")
-    @progress for (x, y) in train
-      x, y = tobatch.((x, y))
-      i += 1
-      ŷ = m(x)
-      any(isnan, ŷ) && error("NaN")
-      Δ = back!(loss, 1, ŷ, y)
-      back!(m, Δ, x)
-      update!(m, η)
-      i % 1000 == 0 && @show accuracy(m, test)
-    end
-  end
-  return m
-end
+stack(xs, dim = 1) = cat(dim, unsqueeze.(xs, dim)...)
+unstack(xs, dim = 1) = [slicedim(xs, dim, i) for i = 1:size(xs, dim)]
 
-function accuracy(m, data)
-  correct = 0
-  for (x, y) in data
-    x, y = tobatch.((x, y))
-    correct += sum(onecold(m(x)) .== onecold(y))
-  end
-  return correct/length(data)
-end
+mapt(f, x) = f(x)
+mapt(f, xs::Tuple) = map(x -> mapt(f, x), xs)
+
+convertel(T::Type, xs::AbstractArray) = convert.(T, xs)
+convertel{T}(::Type{T}, xs::AbstractArray{T}) = xs