add DataLoader

special case train! for the unsupervised data iterator
CarloLucibello 2020-02-26 13:48:27 +01:00
parent 37af9fb15c
commit b6c79b38b4
11 changed files with 253 additions and 53 deletions


@ -252,7 +252,7 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"
[[Pkg]]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
[[Printf]]


@ -40,7 +40,10 @@ julia = "1"
[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
[targets]
test = ["Test", "Documenter"]
test = ["Test", "Documenter", "IterTools", "LinearAlgebra"]


@ -15,10 +15,12 @@ makedocs(modules=[Flux, NNlib],
"Regularisation" => "models/regularisation.md",
"Model Reference" => "models/layers.md",
"NNlib" => "models/nnlib.md"],
"Handling Data" =>
["One-Hot Encoding" => "data/onehot.md",
"DataLoader" => "data/dataloader.md"],
"Training Models" =>
["Optimisers" => "training/optimisers.md",
"Training" => "training/training.md"],
"One-Hot Encoding" => "data/onehot.md",
"GPU Support" => "gpu.md",
"Saving & Loading" => "saving.md",
"Performance Tips" => "performance.md",


@ -0,0 +1,6 @@
# DataLoader
Flux provides the `DataLoader` type in the `Flux.Data` module to handle iteration over mini-batches of data.
```@docs
Flux.Data.DataLoader
```
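A short usage sketch, along the lines of the examples in the docstring (the array sizes here are arbitrary):
```julia
using Flux.Data: DataLoader

Xtrain = rand(10, 100)    # 100 observations, each a length-10 column
Ytrain = rand(100)        # one target per observation

loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)

for (x, y) in loader      # 50 mini-batches per pass over the data
    @assert size(x) == (10, 2)
    @assert size(y) == (2,)
end
```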


@ -7,10 +7,10 @@ To actually train a model we need four things:
* A collection of data points that will be provided to the objective function.
* An [optimiser](optimisers.md) that will update the model parameters appropriately.
With these we can call `Flux.train!`:
With these we can call `train!`:
```julia
Flux.train!(objective, params, data, opt)
```@docs
Flux.Optimise.train!
```
There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
@ -56,7 +56,8 @@ data = [(x, y)]
```julia
data = [(x, y), (x, y), (x, y)]
# Or equivalently
data = Iterators.repeated((x, y), 3)
using IterTools: ncycle
data = ncycle([(x, y)], 3)
```
It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
@ -67,6 +68,14 @@ ys = [rand( 10), rand( 10), rand( 10)]
data = zip(xs, ys)
```
Training data can be conveniently partitioned for mini-batch training using the [`Flux.Data.DataLoader`](@ref) type:
```julia
X = rand(28, 28, 60000)
Y = rand(0:9, 60000)
data = DataLoader(X, Y, batchsize=128)
```
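This iterator can then be fed straight to `train!`, which draws one `(x, y)` mini-batch per gradient step and splats it into the loss. A minimal sketch building on the `data` defined above (the model, loss and optimiser below are illustrative placeholders, not part of the example itself):
```julia
m = Chain(x -> reshape(x, :, size(x)[end]),   # flatten each 28×28 image into a column
          Dense(784, 10), softmax)
loss(x, y) = Flux.crossentropy(m(x), Flux.onehotbatch(y, 0:9))
opt = ADAM()

Flux.train!(loss, params(m), data, opt)       # one pass over all mini-batches
```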
Note that, by default, `train!` only loops over the data once (a single "epoch").
A convenient way to run multiple epochs from the REPL is provided by `@epochs`.
@ -120,7 +129,7 @@ An example follows that works similarly to the default `Flux.train!` but with no callbacks.
You don't need callbacks if you just code the calls to your functions directly into the loop, e.g. in the places marked with comments.
```
```julia
function my_custom_train!(loss, ps, data, opt)
ps = Params(ps)
for d in data


@ -7,6 +7,7 @@ using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd
export gradient
export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,


@ -3,6 +3,9 @@ module Data
import ..Flux
import SHA
using Random: shuffle!
using Base: @propagate_inbounds
export CMUDict, cmudict
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
@ -26,6 +29,9 @@ function __init__()
mkpath(deps())
end
include("dataloader.jl")
export DataLoader
include("mnist.jl")
export MNIST
@ -42,7 +48,11 @@ using .Sentiment
include("iris.jl")
export Iris
include("housing.jl")
export Housing

end # module

src/data/dataloader.jl (new file, 88 lines)

@ -0,0 +1,88 @@
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
struct DataLoader
data
batchsize::Int
nobs::Int
partial::Bool
imax::Int
indices::Vector{Int}
shuffle::Bool
end
"""
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
supervised learning. The last dimension in each tensor is considered to be the observation
dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
Example usage:
Xtrain = rand(10, 100)
dtrain = DataLoader(Xtrain, batchsize=2)
# iterate over 50 mini-batches
for x in dtrain
@assert size(x) == (10, 2)
...
end
Xtrain = rand(10, 100)
Ytrain = rand(100)
dtrain = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
for epoch in 1:100
for (x, y) in dtrain
@assert size(x) == (10, 2)
@assert size(y) == (2,)
...
end
end
# train for 10 epochs
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(dtrain, 10), opt)
"""
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
nx = size(data[1])[end]
for i=2:length(data)
nx != size(data[i])[end] && throw(DimensionMismatch("All data inputs should contain the same number of observations"))
end
if nx < batchsize
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
batchsize = nx
end
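# imax is the iteration cutoff: with partial=false, stop before an incomplete final mini-batch would be produced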
imax = partial ? nx : nx - batchsize + 1
ids = 1:min(nx, batchsize)
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
end
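# select observations ids along the last (observation) dimension, keeping all other dimensions whole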
getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
i >= d.imax && return nothing
if d.shuffle && i == 0
shuffle!(d.indices)
end
nexti = min(i + d.batchsize, d.nobs)
ids = d.indices[i+1:nexti]
if length(d.data) == 1
batch = getdata(d.data[1], ids)
else
batch = ((getdata(x, ids) for x in d.data)...,)
end
return (batch, nexti)
end
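# number of mini-batches per epoch; a trailing partial batch counts only when partial=true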
function Base.length(d::DataLoader)
n = d.nobs / d.batchsize
d.partial ? ceil(Int,n) : floor(Int,n)
end


@ -61,13 +61,14 @@ end
For each datapoint `d` in `data`, computes the gradient of `loss(d...)` through
backpropagation and calls the optimiser `opt`.
If a datapoint `d` is an array rather than a tuple, no splatting is performed
and the gradient of `loss(d)` is computed instead.
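For example, an unsupervised data iterator yielding plain arrays can be used with a
one-argument loss (an illustrative sketch; the model and loss are placeholders):
```julia
X = rand(Float32, 10, 100)
data = Flux.Data.DataLoader(X, batchsize=10)
m = Dense(10, 2)
loss(x) = sum(abs2, m(x))   # receives each mini-batch array directly, no splatting
Flux.train!(loss, params(m), data, Descent(0.1))
```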
Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:
```julia
Flux.train!(loss, params, data, opt,
train!(loss, params, data, opt,
cb = throttle(() -> println("training"), 10))
```
The callback can call `Flux.stop()` to interrupt the training loop.
@ -78,9 +79,15 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb)
@progress for d in data
try
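# an array datapoint (e.g. from an unsupervised data iterator) goes to the loss whole; tuples are splatted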
if d isa AbstractArray
gs = gradient(ps) do
loss(d)
end
else
gs = gradient(ps) do
loss(d...)
end
end
update!(opt, ps, gs)
cb()
catch ex


@ -1,28 +1,85 @@
using Flux.Data
using Test
@testset "DataLoader" begin
X = reshape([1:10;], (2, 5))
Y = [1:5;]
d = DataLoader(X, batchsize=2)
batches = collect(d)
@test length(batches) == 3
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]
@test batches[3] == X[:,5:5]
d = DataLoader(X, batchsize=2, partial=false)
batches = collect(d)
@test length(batches) == 2
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]
d = DataLoader(X, Y, batchsize=2)
batches = collect(d)
@test length(batches) == 3
@test length(batches[1]) == 2
@test length(batches[2]) == 2
@test length(batches[3]) == 2
@test batches[1][1] == X[:,1:2]
@test batches[1][2] == Y[1:2]
@test batches[2][1] == X[:,3:4]
@test batches[2][2] == Y[3:4]
@test batches[3][1] == X[:,5:5]
@test batches[3][2] == Y[5:5]
# test interaction with `train!` (unsupervised: single-array datapoints)
θ = ones(2)
X = zeros(2, 10)
loss(x) = sum((x .- θ).^2)
d = DataLoader(X)
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
@test norm(θ) < 1e-4
# test interaction with `train!` (supervised: (x, y) datapoints)
θ = zeros(2)
X = ones(2, 10)
Y = fill(2, 10)
loss(x, y) = sum((y - x'*θ).^2)
d = DataLoader(X, Y)
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
@test norm(θ .- 1) < 1e-10
end
@testset "CMUDict" begin
@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
@test length(CMUDict.phones()) == 39
@test length(CMUDict.symbols()) == 84
end
@testset "MNIST" begin
@test MNIST.images()[1] isa Matrix
@test MNIST.labels() isa Vector{Int64}
end
@testset "FashionMNIST" begin
@test FashionMNIST.images()[1] isa Matrix
@test FashionMNIST.labels() isa Vector{Int64}
end
@testset "Sentiment" begin
@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
end
@testset "Iris" begin
@test Iris.features() isa Matrix
@test size(Iris.features()) == (4,150)
@test Iris.labels() isa Vector{String}
@test size(Iris.labels()) == (150,)
end
@testset "Housing" begin
@test Housing.features() isa Matrix
@test size(Housing.features()) == (506, 13)
@test Housing.targets() isa Array{Float64}
@test size(Housing.targets()) == (506, 1)
end


@ -1,32 +1,49 @@
using Flux, Test, Random, Statistics, Documenter
using Random
using Flux
using Flux.Data
using Test
using Random, Statistics, LinearAlgebra
using Documenter
using IterTools: ncycle
Random.seed!(0)
@testset "Flux" begin
@info "Testing Basics"
@testset "Utils" begin
include("utils.jl")
end
@testset "Onehot" begin
include("onehot.jl")
end
@testset "Optimise" begin
include("optimise.jl")
end
@testset "Data" begin
include("data.jl")
end
@info "Testing Layers"
@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")
end
@testset "CUDA" begin
if Flux.use_cuda[]
include("cuda/cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
end
@testset "Docs" begin
if VERSION >= v"1.2"
doctest(Flux)
end
end
end # testset Flux