add DataLoader

special case train! for the unsupervised data iterator
CarloLucibello 2020-02-26 13:48:27 +01:00
parent 37af9fb15c
commit b6c79b38b4
11 changed files with 253 additions and 53 deletions


@@ -252,7 +252,7 @@ uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
version = "1.1.0"

[[Pkg]]
-deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "UUIDs"]
+deps = ["Dates", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Test", "UUIDs"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"

[[Printf]]


@@ -40,7 +40,10 @@ julia = "1"
[extras]
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
+IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
-test = ["Test", "Documenter"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra"]


@@ -15,10 +15,12 @@ makedocs(modules=[Flux, NNlib],
          "Regularisation" => "models/regularisation.md",
          "Model Reference" => "models/layers.md",
          "NNlib" => "models/nnlib.md"],
+        "Handling Data" =>
+          ["One-Hot Encoding" => "data/onehot.md",
+           "DataLoader" => "data/dataloader.md"],
         "Training Models" =>
           ["Optimisers" => "training/optimisers.md",
            "Training" => "training/training.md"],
-        "One-Hot Encoding" => "data/onehot.md",
         "GPU Support" => "gpu.md",
         "Saving & Loading" => "saving.md",
         "Performance Tips" => "performance.md",


@@ -0,0 +1,6 @@
# DataLoader

Flux provides the `DataLoader` type in the `Flux.Data` module to handle iteration over mini-batches of data.

```@docs
Flux.Data.DataLoader
```
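For a quick feel of the API, here is a minimal usage sketch based on the `DataLoader` docstring added in this commit (shapes chosen for illustration):

```julia
using Flux.Data: DataLoader

X = rand(10, 100)   # 100 observations; the last dimension is the observation dimension
Y = rand(100)

loader = DataLoader(X, Y, batchsize=20, shuffle=true)
for (x, y) in loader
    # x is a 10×20 slice of X and y the matching 20 targets;
    # the last mini-batch may be smaller unless partial=false
end
```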


@@ -7,10 +7,10 @@ To actually train a model we need four things:
* A collection of data points that will be provided to the objective function.
* An [optimiser](optimisers.md) that will update the model parameters appropriately.

-With these we can call `Flux.train!`:
-```julia
-Flux.train!(objective, params, data, opt)
-```
+With these we can call `train!`:
+```@docs
+Flux.Optimise.train!
+```

There are plenty of examples in the [model zoo](https://github.com/FluxML/model-zoo).
@@ -56,7 +56,8 @@ data = [(x, y)]
```julia
data = [(x, y), (x, y), (x, y)]
# Or equivalently
-data = Iterators.repeated((x, y), 3)
+using IterTools: ncycle
+data = ncycle([(x, y)], 3)
```

It's common to load the `x`s and `y`s separately. In this case you can use `zip`:
@@ -67,6 +68,14 @@ ys = [rand(10), rand(10), rand(10)]
data = zip(xs, ys)
```

+Training data can be conveniently partitioned for mini-batch training using the [`Flux.Data.DataLoader`](@ref) type:
+```julia
+X = rand(28, 28, 60000)
+Y = rand(0:9, 60000)
+data = DataLoader(X, Y, batchsize=128)
+```

Note that, by default, `train!` only loops over the data once (a single "epoch").
A convenient way to run multiple epochs from the REPL is provided by `@epochs`.
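For instance, a small sketch reusing the `loss`, `params`, `data` and `opt` from above (`@epochs` comes from `Flux`):

```julia
using Flux: @epochs

@epochs 2 Flux.train!(loss, params, data, opt)  # runs train! twice over data
```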
@@ -120,7 +129,7 @@ An example follows that works similar to the default `Flux.train!` but with no callbacks.
You don't need callbacks if you just code the calls to your functions directly into the loop.
E.g. in the places marked with comments.

-```
+```julia
function my_custom_train!(loss, ps, data, opt)
  ps = Params(ps)
  for d in data
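The loop body is truncated by the diff view; for reference, a self-contained sketch of how such a custom loop can be completed (a sketch only, mirroring the `train!` body changed later in this commit; the comments mark the places the surrounding text refers to):

```julia
using Flux
using Flux.Optimise: update!
using Zygote: Params, gradient

function my_custom_train!(loss, ps, data, opt)
    ps = Params(ps)
    for d in data
        gs = gradient(ps) do
            # insert whatever code you want here, e.g. logging the loss
            loss(d...)
        end
        # insert whatever code you want here that needs the gradient
        update!(opt, ps, gs)
        # here you might like to check validation accuracy and break out for early stopping
    end
end
```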


@@ -7,6 +7,7 @@ using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using MacroTools: @forward
@reexport using NNlib
+using Zygote: Params, @adjoint, gradient, pullback, @nograd

export gradient

export Chain, Dense, Maxout, RNN, LSTM, GRU, Conv, CrossCor, ConvTranspose, MaxPool, MeanPool,


@@ -3,6 +3,9 @@ module Data
import ..Flux
import SHA
+using Random: shuffle!
+using Base: @propagate_inbounds

export CMUDict, cmudict

deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)

@@ -26,6 +29,9 @@ function __init__()
  mkpath(deps())
end

+include("dataloader.jl")
+export DataLoader

include("mnist.jl")
export MNIST
@@ -42,7 +48,11 @@ using .Sentiment
include("iris.jl")
export Iris

include("housing.jl")
export Housing

-end
+end # module

src/data/dataloader.jl (new file, 88 lines)

@@ -0,0 +1,88 @@
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
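# Fields: `data` is the tuple of input tensors; `nobs` the number of observations
# (the size of the last dimension); `imax` the last index at which a mini-batch may
# start, so short trailing batches are dropped when `partial=false`; `indices` the
# (possibly shuffled) iteration order over observations.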
struct DataLoader
    data
    batchsize::Int
    nobs::Int
    partial::Bool
    imax::Int
    indices::Vector{Int}
    shuffle::Bool
end
"""
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
supervised learning. The last dimension in each tensor is considered to be the observation
dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
Example usage:
Xtrain = rand(10, 100)
dtrain = DataLoader(Xtrain, batchsize=2)
# iterate over 50 mini-batches
for x in dtrain:
@assert size(x) == (10, 2)
...
end
Xtrain = rand(10, 100)
Ytrain = rand(100)
dtrain = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
for epoch in 1:100
for (x, y) in dtrain:
@assert size(x) == (10, 2)
@assert size(y) == (2,)
...
end
end
# train for 10 epochs
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(dtrain, 10), opt)
"""
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
    length(data) > 0 || throw(ArgumentError("Need at least one data input"))
    batchsize > 0 || throw(ArgumentError("Need positive batchsize"))

    nx = size(data[1])[end]
    for i = 2:length(data)
        nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations"))
    end
    if nx < batchsize
        @warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
        batchsize = nx
    end
    imax = partial ? nx : nx - batchsize + 1
    DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
end
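# selects the observations `ids` along the last (observation) dimension, keeping all leading dimensions whole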
getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
@propagate_inbounds function Base.iterate(d::DataLoader, i=0)  # returns data in d.indices[i+1:i+batchsize]
    i >= d.imax && return nothing
    if d.shuffle && i == 0
        shuffle!(d.indices)
    end
    nexti = min(i + d.batchsize, d.nobs)
    ids = d.indices[i+1:nexti]
    if length(d.data) == 1
        batch = getdata(d.data[1], ids)
    else
        batch = ((getdata(x, ids) for x in d.data)...,)
    end
    return (batch, nexti)
end
function Base.length(d::DataLoader)
    n = d.nobs / d.batchsize
    d.partial ? ceil(Int, n) : floor(Int, n)
end
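As a small check of the batching arithmetic above (a sketch using the constructor and `length` just defined): with 5 observations and `batchsize=2`, `partial=true` yields batches of 2, 2 and 1 observations, while `partial=false` drops the short trailing batch:

```julia
d = DataLoader(rand(3, 5), batchsize=2)
length(d)                # 3 == ceil(Int, 5 / 2)
[size(x, 2) for x in d]  # [2, 2, 1]

d = DataLoader(rand(3, 5), batchsize=2, partial=false)
length(d)                # 2 == floor(Int, 5 / 2)
```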


@@ -61,13 +61,14 @@
For each datapoint `d` in `data` computes the gradient of `loss(d...)` through
backpropagation and calls the optimizer `opt`.

+In case datapoints `d` are of array type, assumes no splatting is needed
+and computes the gradient of `loss(d)`.

Takes a callback as keyword argument `cb`. For example, this will print "training"
every 10 seconds:

-```julia
-Flux.train!(loss, params, data, opt,
-            cb = throttle(() -> println("training"), 10))
-```
+    train!(loss, params, data, opt,
+           cb = throttle(() -> println("training"), 10))

The callback can call `Flux.stop()` to interrupt the training loop.
@@ -78,8 +79,14 @@ function train!(loss, ps, data, opt; cb = () -> ())
  cb = runall(cb)
  @progress for d in data
    try
-      gs = gradient(ps) do
-        loss(d...)
-      end
+      if d isa AbstractArray
+        gs = gradient(ps) do
+          loss(d)
+        end
+      else
+        gs = gradient(ps) do
+          loss(d...)
+        end
+      end
      update!(opt, ps, gs)
      cb()
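With this special case in place, an iterator that yields plain arrays (e.g. a `DataLoader` over a single tensor, as in unsupervised learning) can be passed straight to `train!`. A minimal sketch, assuming `model` and `opt` are already defined:

```julia
using Flux

X = rand(10, 100)                      # unlabelled data
loader = Flux.Data.DataLoader(X, batchsize=20)
loss(x) = Flux.mse(model(x), x)        # e.g. an autoencoder reconstruction loss
Flux.train!(loss, Flux.params(model), loader, opt)  # each batch x is passed as loss(x), not loss(x...)
```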


@@ -1,28 +1,85 @@
-using Flux.Data
-using Test
-@test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
-@test length(CMUDict.phones()) == 39
-@test length(CMUDict.symbols()) == 84
-@test MNIST.images()[1] isa Matrix
-@test MNIST.labels() isa Vector{Int64}
-@test FashionMNIST.images()[1] isa Matrix
-@test FashionMNIST.labels() isa Vector{Int64}
-@test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
-@test Iris.features() isa Matrix
-@test size(Iris.features()) == (4,150)
-@test Iris.labels() isa Vector{String}
-@test size(Iris.labels()) == (150,)
-@test Housing.features() isa Matrix
-@test size(Housing.features()) == (506, 13)
-@test Housing.targets() isa Array{Float64}
-@test size(Housing.targets()) == (506, 1)
+@testset "DataLoader" begin
+    X = reshape([1:10;], (2, 5))
+    Y = [1:5;]
+
+    d = DataLoader(X, batchsize=2)
+    batches = collect(d)
+    @test length(batches) == 3
+    @test batches[1] == X[:,1:2]
+    @test batches[2] == X[:,3:4]
+    @test batches[3] == X[:,5:5]
+
+    d = DataLoader(X, batchsize=2, partial=false)
+    batches = collect(d)
+    @test length(batches) == 2
+    @test batches[1] == X[:,1:2]
+    @test batches[2] == X[:,3:4]
+
+    d = DataLoader(X, Y, batchsize=2)
+    batches = collect(d)
+    @test length(batches) == 3
+    @test length(batches[1]) == 2
+    @test length(batches[2]) == 2
+    @test length(batches[3]) == 2
+    @test batches[1][1] == X[:,1:2]
+    @test batches[1][2] == Y[1:2]
+    @test batches[2][1] == X[:,3:4]
+    @test batches[2][2] == Y[3:4]
+    @test batches[3][1] == X[:,5:5]
+    @test batches[3][2] == Y[5:5]
+
+    # test interaction with `train!`
+    θ = ones(2)
+    X = zeros(2, 10)
+    loss(x) = sum((x .- θ).^2)
+    d = DataLoader(X)
+    Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
+    @test norm(θ) < 1e-4
+
+    # test interaction with `train!`
+    θ = zeros(2)
+    X = ones(2, 10)
+    Y = fill(2, 10)
+    loss(x, y) = sum((y - x'*θ).^2)
+    d = DataLoader(X, Y)
+    Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
+    @test norm(θ .- 1) < 1e-10
+end
+
+@testset "CMUDict" begin
+    @test cmudict()["CATASTROPHE"] == :[K,AH0,T,AE1,S,T,R,AH0,F,IY0].args
+    @test length(CMUDict.phones()) == 39
+    @test length(CMUDict.symbols()) == 84
+end
+
+@testset "MNIST" begin
+    @test MNIST.images()[1] isa Matrix
+    @test MNIST.labels() isa Vector{Int64}
+end
+
+@testset "FashionMNIST" begin
+    @test FashionMNIST.images()[1] isa Matrix
+    @test FashionMNIST.labels() isa Vector{Int64}
+end
+
+@testset "Sentiment" begin
+    @test Data.Sentiment.train() isa Vector{Data.Tree{Any}}
+end
+
+@testset "Iris" begin
+    @test Iris.features() isa Matrix
+    @test size(Iris.features()) == (4,150)
+    @test Iris.labels() isa Vector{String}
+    @test size(Iris.labels()) == (150,)
+end
+
+@testset "Housing" begin
+    @test Housing.features() isa Matrix
+    @test size(Housing.features()) == (506, 13)
+    @test Housing.targets() isa Array{Float64}
+    @test size(Housing.targets()) == (506, 1)
+end


@@ -1,32 +1,49 @@
-using Flux, Test, Random, Statistics, Documenter
-using Random
+using Flux
+using Flux.Data
+using Test
+using Random, Statistics, LinearAlgebra
+using Documenter
+using IterTools: ncycle

Random.seed!(0)

@testset "Flux" begin

-@info "Testing Basics"
-include("utils.jl")
-include("onehot.jl")
-include("optimise.jl")
-include("data.jl")
-@info "Testing Layers"
-include("layers/basic.jl")
-include("layers/normalisation.jl")
-include("layers/stateless.jl")
-include("layers/conv.jl")
-if Flux.use_cuda[]
-  include("cuda/cuda.jl")
-else
-  @warn "CUDA unavailable, not testing GPU support"
-end
-if VERSION >= v"1.2"
-  doctest(Flux)
-end
+@testset "Utils" begin
+  include("utils.jl")
+end
+
+@testset "Onehot" begin
+  include("onehot.jl")
+end
+
+@testset "Optimise" begin
+  include("optimise.jl")
+end
+
+@testset "Data" begin
+  include("data.jl")
+end
+
+@testset "Layers" begin
+  include("layers/basic.jl")
+  include("layers/normalisation.jl")
+  include("layers/stateless.jl")
+  include("layers/conv.jl")
+end
+
+@testset "CUDA" begin
+  if Flux.use_cuda[]
+    include("cuda/cuda.jl")
+  else
+    @warn "CUDA unavailable, not testing GPU support"
+  end
+end
+
+@testset "Docs" begin
+  if VERSION >= v"1.2"
+    doctest(Flux)
+  end
+end

-end
+end # testset Flux