From b6c79b38b4bf54aba0ee096b38afd1180ad1ee55 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Wed, 26 Feb 2020 13:48:27 +0100 Subject: [PATCH 1/5] add DataLoader special case train! for the unsupervised data iterator --- Manifest.toml | 2 +- Project.toml | 5 +- docs/make.jl | 4 +- docs/src/data/dataloader.md | 6 +++ docs/src/training/training.md | 19 +++++-- src/Flux.jl | 1 + src/data/Data.jl | 10 ++++ src/data/dataloader.jl | 88 +++++++++++++++++++++++++++++++++ src/optimise/train.jl | 19 ++++--- test/data.jl | 93 ++++++++++++++++++++++++++++------- test/runtests.jl | 59 ++++++++++++++-------- 11 files changed, 253 insertions(+), 53 deletions(-) create mode 100644 docs/src/data/dataloader.md create mode 100644 src/data/dataloader.jl diff --git a/Manifest.toml b/Manifest.toml index 693f7ca2..788e5354 100644 --- a/Manifest.toml +++ b/Manifest.toml @@ -252,7 +252,7 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" version = "1.1.0" [[Pkg]] -deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"] +deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"] uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" [[Printf]] diff --git a/Project.toml b/Project.toml index 71282a10..bd105730 100644 --- a/Project.toml +++ b/Project.toml @@ -40,7 +40,10 @@ julia = "1" [extras] Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" +IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" + [targets] -test = ["Test", "Documenter"] +test = ["Test", "Documenter", "IterTools", "LinearAlgebra"] diff --git a/docs/make.jl b/docs/make.jl index fe3544fc..0d597500 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -15,10 +15,12 @@ makedocs(modules=[Flux, NNlib], "Regularisation" => "models/regularisation.md", "Model Reference" => "models/layers.md", "NNlib" => "models/nnlib.md"], + "Handling Data" => + ["One-Hot Encoding" => "data/onehot.md", + "DataLoader" => "data/dataloader.md"], "Training Models" => ["Optimisers" => "training/optimisers.md", "Training" => "training/training.md"], - "One-Hot Encoding" => "data/onehot.md", "GPU Support" => "gpu.md", "Saving & Loading" => "saving.md", "Performance Tips" => "performance.md", diff --git a/docs/src/data/dataloader.md b/docs/src/data/dataloader.md new file mode 100644 index 00000000..70a883c9 --- /dev/null +++ b/docs/src/data/dataloader.md @@ -0,0 +1,6 @@ +# DataLoader +Flux provides the `DataLoader` type in the `Flux.Data` module to handle iteration over mini-batches of data. + +```@docs +Flux.Data.DataLoader +``` \ No newline at end of file diff --git a/docs/src/training/training.md b/docs/src/training/training.md index b42db7c9..64b2b5e8 100644 --- a/docs/src/training/training.md +++ b/docs/src/training/training.md @@ -7,10 +7,10 @@ To actually train a model we need four things: * A collection of data points that will be provided to the objective function. * An [optimiser](optimisers.md) that will update the model parameters appropriately. -With these we can call `Flux.train!`: +With these we can call `train!`: -```julia -Flux.train!(objective, params, data, opt) +```@docs +Flux.Optimise.train! ``` There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo). 
@@ -56,7 +56,8 @@ data = [(x, y)]
 ```julia
 data = [(x, y), (x, y), (x, y)]
 # Or equivalently
-data = Iterators.repeated((x, y), 3)
+using IterTools: ncycle
+data = ncycle([(x, y)], 3)
 ```
 
 It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
@@ -67,6 +68,14 @@ ys = [rand( 10), rand( 10), rand( 10)]
 data = zip(xs, ys)
 ```
 
+Training data can be conveniently partitioned for mini-batch training using the [`Flux.Data.DataLoader`](@ref) type:
+
+```julia
+X = rand(28, 28, 60000)
+Y = rand(0:9, 60000)
+data = DataLoader(X, Y, batchsize=128)
+```
+
 Note that, by default, `train!` only loops over the data once (a single "epoch").
 A convenient way to run multiple epochs from the REPL is provided by `@epochs`.
 
@@ -120,7 +129,7 @@ An example follows that works similar to the default `Flux.train` but with no ca
 You don't need callbacks if you just code the calls to your functions directly into the loop.
 E.g. in the places marked with comments.
 
-```
+```julia
 function my_custom_train!(loss, ps, data, opt)
   ps = Params(ps)
   for d in data
diff --git a/src/Flux.jl b/src/Flux.jl
index 9969b323..c99e41a1 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -7,6 +7,7 @@ using Zygote, MacroTools, Juno, Reexport, Statistics, Random
 using MacroTools: @forward
 @reexport using NNlib
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
+
 export gradient
 
 export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,
diff --git a/src/data/Data.jl b/src/data/Data.jl
index 88af9549..940b7ea7 100644
--- a/src/data/Data.jl
+++ b/src/data/Data.jl
@@ -3,6 +3,9 @@ module Data
 import ..Flux
 import SHA
 
+using Random: shuffle!
+using Base: @propagate_inbounds
+
 export CMUDict, cmudict
 
 deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
@@ -26,6 +29,9 @@ function __init__()
   mkpath(deps())
 end
 
+include("dataloader.jl")
+export DataLoader
+
 include("mnist.jl")
 export MNIST
 
@@ -42,7 +48,11 @@ using .Sentiment
 include("iris.jl")
 export Iris
 
+<<<<<<< HEAD
 include("housing.jl")
 export Housing
 
 end
+=======
+end #module
+>>>>>>> af20a785... add DataLoader
diff --git a/src/data/dataloader.jl b/src/data/dataloader.jl
new file mode 100644
index 00000000..baf32a83
--- /dev/null
+++ b/src/data/dataloader.jl
@@ -0,0 +1,88 @@
+# Adapted from Knet's src/data.jl (author: Deniz Yuret)
+
+struct DataLoader
+    data
+    batchsize::Int
+    nobs::Int
+    partial::Bool
+    imax::Int
+    indices::Vector{Int}
+    shuffle::Bool
+end
+
+"""
+    DataLoader(data...; batchsize=1, shuffle=false, partial=true)
+
+An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
+(except possibly the last one).
+
+Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
+supervised learning. The last dimension in each tensor is considered to be the observation
+dimension.
+
+If `shuffle=true`, shuffles the observations each time iterations are re-started.
+If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
+
+Example usage:
+
+    Xtrain = rand(10, 100)
+    dtrain = DataLoader(Xtrain, batchsize=2)
+    # iterate over 50 mini-batches
+    for x in dtrain
+        @assert size(x) == (10, 2)
+        ...
+    end
+
+    Xtrain = rand(10, 100)
+    Ytrain = rand(100)
+    dtrain = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
+    for epoch in 1:100
+        for (x, y) in dtrain
+            @assert size(x) == (10, 2)
+            @assert size(y) == (2,)
+            ...
+ end + end + + # train for 10 epochs + using IterTools: ncycle + Flux.train!(loss, ps, ncycle(dtrain, 10), opt) +""" +function DataLoader(data...; batchsize=1, shuffle=false, partial=true) + length(data) > 0 || throw(ArgumentError("Need at least one data input")) + batchsize > 0 || throw(ArgumentError("Need positive batchsize")) + + nx = size(data[1])[end] + for i=2:length(data) + nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations")) + end + if nx < batchsize + @warn "Number of data points less than batchsize, decreasing the batchsize to $nx" + batchsize = nx + end + imax = partial ? nx : nx - batchsize + 1 + ids = 1:min(nx, batchsize) + DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle) +end + +getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids] + +@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize] + i >= d.imax && return nothing + if d.shuffle && i == 0 + shuffle!(d.indices) + end + nexti = min(i + d.batchsize, d.nobs) + ids = d.indices[i+1:nexti] + if length(d.data) == 1 + batch = getdata(d.data[1], ids) + else + batch = ((getdata(x, ids) for x in d.data)...,) + end + return (batch, nexti) +end + +function Base.length(d::DataLoader) + n = d.nobs / d.batchsize + d.partial ? ceil(Int,n) : floor(Int,n) +end \ No newline at end of file diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 59404a42..34a98394 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -61,13 +61,14 @@ end For each datapoint `d` in `data` computes the gradient of `loss(d...)` through backpropagation and calls the optimizer `opt`. +In case datapoints `d` are of array type, assumes no splatting is needed +and computes the gradient of `loss(d)`. + Takes a callback as keyword argument `cb`. For example, this will print "training" every 10 seconds: -```julia -Flux.train!(loss, params, data, opt, - cb = throttle(() -> println("training"), 10)) -``` + train!(loss, params, data, opt, + cb = throttle(() -> println("training"), 10)) The callback can call `Flux.stop()` to interrupt the training loop. @@ -78,8 +79,14 @@ function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) @progress for d in data try - gs = gradient(ps) do - loss(d...) + if d isa AbstractArray + gs = gradient(ps) do + loss(d) + end + else + gs = gradient(ps) do + loss(d...) 
+ end end update!(opt, ps, gs) cb() diff --git a/test/data.jl b/test/data.jl index 6c012a93..1a090174 100644 --- a/test/data.jl +++ b/test/data.jl @@ -1,28 +1,85 @@ -using Flux.Data -using Test +@testset "DataLoader" begin + X = reshape([1:10;], (2, 5)) + Y = [1:5;] -@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args + d = DataLoader(X, batchsize=2) + batches = collect(d) + @test length(batches) == 3 + @test batches[1] == X[:,1:2] + @test batches[2] == X[:,3:4] + @test batches[3] == X[:,5:5] -@test length(CMUDict.phones()) == 39 + d = DataLoader(X, batchsize=2, partial=false) + batches = collect(d) + @test length(batches) == 2 + @test batches[1] == X[:,1:2] + @test batches[2] == X[:,3:4] -@test length(CMUDict.symbols()) == 84 + d = DataLoader(X, Y, batchsize=2) + batches = collect(d) + @test length(batches) == 3 + @test length(batches[1]) == 2 + @test length(batches[2]) == 2 + @test length(batches[3]) == 2 + @test batches[1][1] == X[:,1:2] + @test batches[1][2] == Y[1:2] + @test batches[2][1] == X[:,3:4] + @test batches[2][2] == Y[3:4] + @test batches[3][1] == X[:,5:5] + @test batches[3][2] == Y[5:5] -@test MNIST.images()[1] isa Matrix -@test MNIST.labels() isa Vector{Int64} + # test interaction with `train!` + θ = ones(2) + X = zeros(2, 10) + loss(x) = sum((x .- θ).^2) + d = DataLoader(X) + Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1)) + @test norm(θ) < 1e-4 -@test FashionMNIST.images()[1] isa Matrix -@test FashionMNIST.labels() isa Vector{Int64} + # test interaction with `train!` + θ = zeros(2) + X = ones(2, 10) + Y = fill(2, 10) + loss(x, y) = sum((y - x'*θ).^2) + d = DataLoader(X, Y) + Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1)) + @test norm(θ .- 1) < 1e-10 +end -@test Data.Sentiment.train() isa Vector{Data.Tree{Any}} +@testset "CMUDict" begin + @test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args -@test Iris.features() isa Matrix -@test size(Iris.features()) == (4,150) + @test length(CMUDict.phones()) == 39 -@test Iris.labels() isa Vector{String} -@test size(Iris.labels()) == (150,) + @test length(CMUDict.symbols()) == 84 +end -@test Housing.features() isa Matrix -@test size(Housing.features()) == (506, 13) +@testset "MNIST" begin + @test MNIST.images()[1] isa Matrix + @test MNIST.labels() isa Vector{Int64} +end -@test Housing.targets() isa Array{Float64} -@test size(Housing.targets()) == (506, 1) +@testset "FashionMNIST" begin + @test FashionMNIST.images()[1] isa Matrix + @test FashionMNIST.labels() isa Vector{Int64} +end + +@testset "Sentiment" begin + @test Data.Sentiment.train() isa Vector{Data.Tree{Any}} +end + +@testset "Iris" begin + @test Iris.features() isa Matrix + @test size(Iris.features()) == (4,150) + + @test Iris.labels() isa Vector{String} + @test size(Iris.labels()) == (150,) +end + +@testest "Housing" begin + @test Housing.features() isa Matrix + @test size(Housing.features()) == (506, 13) + + @test Housing.targets() isa Array{Float64} + @test size(Housing.targets()) == (506, 1) +end diff --git a/test/runtests.jl b/test/runtests.jl index 1505e96a..81182f0d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,32 +1,49 @@ -using Flux, Test, Random, Statistics, Documenter -using Random +using Flux +using Flux.Data +using Test +using Random, Statistics, LinearAlgebra +using Documenter +using IterTools: ncycle Random.seed!(0) @testset "Flux" begin -@info "Testing Basics" + @testset "Utils" begin + include("utils.jl") + end -include("utils.jl") -include("onehot.jl") -include("optimise.jl") -include("data.jl") + 
@testset "Onehot" begin + include("onehot.jl") + end -@info "Testing Layers" + @testset "Optimise" begin + include("optimise.jl") + end -include("layers/basic.jl") -include("layers/normalisation.jl") -include("layers/stateless.jl") -include("layers/conv.jl") + @testset "Data" begin + include("data.jl") + end -if Flux.use_cuda[] - include("cuda/cuda.jl") -else - @warn "CUDA unavailable, not testing GPU support" -end + @testset "Layers" begin + include("layers/basic.jl") + include("layers/normalisation.jl") + include("layers/stateless.jl") + include("layers/conv.jl") + end -if VERSION >= v"1.2" - doctest(Flux) -end + @testset "CUDA" begin + if Flux.use_cuda[] + include("cuda/cuda.jl") + else + @warn "CUDA unavailable, not testing GPU support" + end + end -end + @testset "Docs" begin + if VERSION >= v"1.2" + doctest(Flux) + end + end + +end # testset Flux From 487002878ed530303cf9527e7cca0ea57b34d5b2 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 27 Feb 2020 20:49:05 +0100 Subject: [PATCH 2/5] restrict train! special casing --- src/optimise/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 34a98394..54b7f53a 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -79,7 +79,7 @@ function train!(loss, ps, data, opt; cb = () -> ()) cb = runall(cb) @progress for d in data try - if d isa AbstractArray + if d isa AbstractArray{<:Number} gs = gradient(ps) do loss(d) end From 97141e8c98fc94feadbe287f45a32b58bd3d515c Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Thu, 27 Feb 2020 20:49:55 +0100 Subject: [PATCH 3/5] improve docstring --- src/optimise/train.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/optimise/train.jl b/src/optimise/train.jl index 54b7f53a..79ebcc06 100644 --- a/src/optimise/train.jl +++ b/src/optimise/train.jl @@ -61,7 +61,7 @@ end For each datapoint `d` in `data` computes the gradient of `loss(d...)` through backpropagation and calls the optimizer `opt`. -In case datapoints `d` are of array type, assumes no splatting is needed +In case datapoints `d` are of numeric array type, assumes no splatting is needed and computes the gradient of `loss(d)`. Takes a callback as keyword argument `cb`. For example, this will print "training" From a72258ea2a428ce4b12e711395856091f17f9fcc Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 29 Feb 2020 18:55:49 +0100 Subject: [PATCH 4/5] fix rebase --- src/data/Data.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/data/Data.jl b/src/data/Data.jl index 940b7ea7..16a025a7 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -48,11 +48,7 @@ using .Sentiment include("iris.jl") export Iris -<<<<<<< HEAD include("housing.jl") export Housing end -======= -end #module ->>>>>>> af20a785... add DataLoader From a1efc434c21d2e4026e5d4f8764854451bac88c5 Mon Sep 17 00:00:00 2001 From: CarloLucibello Date: Sat, 29 Feb 2020 19:40:44 +0100 Subject: [PATCH 5/5] fix typo --- test/data.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/data.jl b/test/data.jl index 1a090174..c7a8fdfd 100644 --- a/test/data.jl +++ b/test/data.jl @@ -76,7 +76,7 @@ end @test size(Iris.labels()) == (150,) end -@testest "Housing" begin +@testset "Housing" begin @test Housing.features() isa Matrix @test size(Housing.features()) == (506, 13)