extend dataloader

This commit is contained in:
CarloLucibello 2020-04-29 10:18:16 +02:00
parent 792a1c54f8
commit a643cb6758
7 changed files with 92 additions and 70 deletions

View File

@ -1,6 +1,6 @@
name = "Flux" name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c" uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.10.5" version = "0.11.0"
[deps] [deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

View File

@ -51,4 +51,6 @@ export Iris
include("housing.jl") include("housing.jl")
export Housing export Housing
@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)
end end

View File

@ -1,7 +1,7 @@
# Adapted from Knet's src/data.jl (author: Deniz Yuret) # Adapted from Knet's src/data.jl (author: Deniz Yuret)
struct DataLoader struct DataLoader{D}
data data::D
batchsize::Int batchsize::Int
nobs::Int nobs::Int
partial::Bool partial::Bool
@ -11,21 +11,20 @@ struct DataLoader
end end
""" """
DataLoader(data...; batchsize=1, shuffle=false, partial=true) DataLoader(data; batchsize=1, shuffle=false, partial=true)
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one). (except possibly the last one).
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in Takes as input a data tensors or a tuple of one or more such tensors.
supervised learning. The last dimension in each tensor is considered to be the observation The last dimension in each tensor is considered to be the observation dimension.
dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started. If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize. If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
The original data is preserved as a tuple in the `data` field of the DataLoader. The original data is preserved in the `data` field of the DataLoader.
Example usage: Usage example:
Xtrain = rand(10, 100) Xtrain = rand(10, 100)
train_loader = DataLoader(Xtrain, batchsize=2) train_loader = DataLoader(Xtrain, batchsize=2)
@ -37,9 +36,16 @@ Example usage:
train_loader.data # original dataset train_loader.data # original dataset
# similar, but yielding tuples
train_loader = DataLoader((Xtrain,), batchsize=2)
for (x,) in train_loader
@assert size(x) == (10, 2)
...
end
Xtrain = rand(10, 100) Xtrain = rand(10, 100)
Ytrain = rand(100) Ytrain = rand(100)
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true) train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
for epoch in 1:100 for epoch in 1:100
for (x, y) in train_loader for (x, y) in train_loader
@assert size(x) == (10, 2) @assert size(x) == (10, 2)
@ -52,25 +58,18 @@ Example usage:
using IterTools: ncycle using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt) Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
""" """
function DataLoader(data...; batchsize=1, shuffle=false, partial=true) function DataLoader(data; batchsize=1, shuffle=false, partial=true)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
batchsize > 0 || throw(ArgumentError("Need positive batchsize")) batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
nx = size(data[1])[end] n = _nobs(data)
for i=2:length(data) if n < batchsize
nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations")) @warn "Number of observations less than batchsize, decreasing the batchsize to $n"
batchsize = n
end end
if nx < batchsize imax = partial ? n : n - batchsize + 1
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx" DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
batchsize = nx
end
imax = partial ? nx : nx - batchsize + 1
ids = 1:min(nx, batchsize)
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
end end
getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize] @propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
i >= d.imax && return nothing i >= d.imax && return nothing
if d.shuffle && i == 0 if d.shuffle && i == 0
@ -78,11 +77,7 @@ getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
end end
nexti = min(i + d.batchsize, d.nobs) nexti = min(i + d.batchsize, d.nobs)
ids = d.indices[i+1:nexti] ids = d.indices[i+1:nexti]
if length(d.data) == 1 batch = _getobs(d.data, ids)
batch = getdata(d.data[1], ids)
else
batch = ((getdata(x, ids) for x in d.data)...,)
end
return (batch, nexti) return (batch, nexti)
end end
@ -90,3 +85,22 @@ function Base.length(d::DataLoader)
n = d.nobs / d.batchsize n = d.nobs / d.batchsize
d.partial ? ceil(Int,n) : floor(Int,n) d.partial ? ceil(Int,n) : floor(Int,n)
end end
_nobs(data::AbstractArray) = size(data)[end]
function _nobs(data::Tuple)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
n = _nobs(data[1])
if !all(x -> _nobs(x) == n, data[2:end])
throw(DimensionMismatch("All data should contain same number of observations"))
end
return n
end
function _getobs(data::AbstractArray{T,N}, i) where {T,N}
getindex(data, ntuple(i->Colon(), N-1)..., i)
end
_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
Base.eltype(d::DataLoader{D}) where D = D

View File

@ -4,6 +4,7 @@
d = DataLoader(X, batchsize=2) d = DataLoader(X, batchsize=2)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 3 @test length(batches) == 3
@test batches[1] == X[:,1:2] @test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4] @test batches[2] == X[:,3:4]
@ -11,12 +12,21 @@
d = DataLoader(X, batchsize=2, partial=false) d = DataLoader(X, batchsize=2, partial=false)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 2 @test length(batches) == 2
@test batches[1] == X[:,1:2] @test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4] @test batches[2] == X[:,3:4]
d = DataLoader(X, Y, batchsize=2) d = DataLoader((X,), batchsize=2, partial=false)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
@test length(batches) == 2
@test batches[1] == (X[:,1:2],)
@test batches[2] == (X[:,3:4],)
d = DataLoader((X, Y), batchsize=2)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
@test length(batches) == 3 @test length(batches) == 3
@test length(batches[1]) == 2 @test length(batches[1]) == 2
@test length(batches[2]) == 2 @test length(batches[2]) == 2
@ -41,7 +51,7 @@
X = ones(2, 10) X = ones(2, 10)
Y = fill(2, 10) Y = fill(2, 10)
loss(x, y) = sum((y - x'*θ).^2) loss(x, y) = sum((y - x'*θ).^2)
d = DataLoader(X, Y) d = DataLoader((X, Y))
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1)) Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
@test norm(θ .- 1) < 1e-10 @test norm(θ .- 1) < 1e-10
end end

View File

@ -2,49 +2,45 @@ using Flux
using Flux.Data using Flux.Data
using Test using Test
using Random, Statistics, LinearAlgebra using Random, Statistics, LinearAlgebra
using Documenter
using IterTools: ncycle using IterTools: ncycle
Random.seed!(0) Random.seed!(0)
@testset "Flux" begin @testset "Utils" begin
@testset "Utils" begin
include("utils.jl") include("utils.jl")
end end
@testset "Onehot" begin @testset "Onehot" begin
include("onehot.jl") include("onehot.jl")
end end
@testset "Optimise" begin @testset "Optimise" begin
include("optimise.jl") include("optimise.jl")
end end
@testset "Data" begin @testset "Data" begin
include("data.jl") include("data.jl")
end end
@testset "Layers" begin @testset "Layers" begin
include("layers/basic.jl") include("layers/basic.jl")
include("layers/normalisation.jl") include("layers/normalisation.jl")
include("layers/stateless.jl") include("layers/stateless.jl")
include("layers/conv.jl") include("layers/conv.jl")
end end
@testset "CUDA" begin @testset "CUDA" begin
if Flux.use_cuda[] if Flux.use_cuda[]
include("cuda/cuda.jl") include("cuda/cuda.jl")
else else
@warn "CUDA unavailable, not testing GPU support" @warn "CUDA unavailable, not testing GPU support"
end end
end end
@static if VERSION >= v"1.4"
using Documenter
@testset "Docs" begin @testset "Docs" begin
if VERSION >= v"1.4"
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true) DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux) doctest(Flux)
end end
end end
end # testset Flux