extend dataloader

This commit is contained in:
CarloLucibello 2020-04-29 10:18:16 +02:00
parent 792a1c54f8
commit a643cb6758
7 changed files with 92 additions and 70 deletions

View File

@ -384,4 +384,4 @@ version = "0.4.20"
deps = ["MacroTools"]
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
version = "0.2.0"
version = "0.2.0"

View File

@ -1,6 +1,6 @@
name = "Flux"
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
version = "0.10.5"
version = "0.11.0"
[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"

View File

@ -51,4 +51,6 @@ export Iris
include("housing.jl")
export Housing
@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)
end

View File

@ -1,7 +1,7 @@
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
struct DataLoader
data
struct DataLoader{D}
data::D
batchsize::Int
nobs::Int
partial::Bool
@ -11,21 +11,20 @@ struct DataLoader
end
"""
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
DataLoader(data; batchsize=1, shuffle=false, partial=true)
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one).
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
supervised learning. The last dimension in each tensor is considered to be the observation
dimension.
Takes as input a data tensors or a tuple of one or more such tensors.
The last dimension in each tensor is considered to be the observation dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
The original data is preserved as a tuple in the `data` field of the DataLoader.
The original data is preserved in the `data` field of the DataLoader.
Example usage:
Usage example:
Xtrain = rand(10, 100)
train_loader = DataLoader(Xtrain, batchsize=2)
@ -37,9 +36,16 @@ Example usage:
train_loader.data # original dataset
# similar, but yielding tuples
train_loader = DataLoader((Xtrain,), batchsize=2)
for (x,) in train_loader
@assert size(x) == (10, 2)
...
end
Xtrain = rand(10, 100)
Ytrain = rand(100)
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
for epoch in 1:100
for (x, y) in train_loader
@assert size(x) == (10, 2)
@ -52,25 +58,18 @@ Example usage:
using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
"""
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
nx = size(data[1])[end]
for i=2:length(data)
nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations"))
n = _nobs(data)
if n < batchsize
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
batchsize = n
end
if nx < batchsize
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
batchsize = nx
end
imax = partial ? nx : nx - batchsize + 1
ids = 1:min(nx, batchsize)
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
imax = partial ? n : n - batchsize + 1
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
end
getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
i >= d.imax && return nothing
if d.shuffle && i == 0
@ -78,11 +77,7 @@ getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
end
nexti = min(i + d.batchsize, d.nobs)
ids = d.indices[i+1:nexti]
if length(d.data) == 1
batch = getdata(d.data[1], ids)
else
batch = ((getdata(x, ids) for x in d.data)...,)
end
batch = _getobs(d.data, ids)
return (batch, nexti)
end
@ -90,3 +85,22 @@ function Base.length(d::DataLoader)
n = d.nobs / d.batchsize
d.partial ? ceil(Int,n) : floor(Int,n)
end
_nobs(data::AbstractArray) = size(data)[end]
function _nobs(data::Tuple)
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
n = _nobs(data[1])
if !all(x -> _nobs(x) == n, data[2:end])
throw(DimensionMismatch("All data should contain same number of observations"))
end
return n
end
function _getobs(data::AbstractArray{T,N}, i) where {T,N}
getindex(data, ntuple(i->Colon(), N-1)..., i)
end
_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
Base.eltype(d::DataLoader{D}) where D = D

View File

@ -121,4 +121,4 @@ macro epochs(n, ex)
@info "Epoch $i"
$(esc(ex))
end)
end
end

View File

@ -4,6 +4,7 @@
d = DataLoader(X, batchsize=2)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 3
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]
@ -11,12 +12,21 @@
d = DataLoader(X, batchsize=2, partial=false)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 2
@test batches[1] == X[:,1:2]
@test batches[2] == X[:,3:4]
d = DataLoader(X, Y, batchsize=2)
d = DataLoader((X,), batchsize=2, partial=false)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
@test length(batches) == 2
@test batches[1] == (X[:,1:2],)
@test batches[2] == (X[:,3:4],)
d = DataLoader((X, Y), batchsize=2)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
@test length(batches) == 3
@test length(batches[1]) == 2
@test length(batches[2]) == 2
@ -41,7 +51,7 @@
X = ones(2, 10)
Y = fill(2, 10)
loss(x, y) = sum((y - x'*θ).^2)
d = DataLoader(X, Y)
d = DataLoader((X, Y))
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
@test norm(θ .- 1) < 1e-10
end

View File

@ -2,49 +2,45 @@ using Flux
using Flux.Data
using Test
using Random, Statistics, LinearAlgebra
using Documenter
using IterTools: ncycle
Random.seed!(0)
@testset "Flux" begin
@testset "Utils" begin
include("utils.jl")
end
@testset "Utils" begin
include("utils.jl")
end
@testset "Onehot" begin
include("onehot.jl")
end
@testset "Optimise" begin
include("optimise.jl")
end
@testset "Data" begin
include("data.jl")
end
@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")
end
@testset "CUDA" begin
if Flux.use_cuda[]
include("cuda/cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
@testset "Onehot" begin
include("onehot.jl")
end
@testset "Optimise" begin
include("optimise.jl")
end
@testset "Data" begin
include("data.jl")
end
@testset "Layers" begin
include("layers/basic.jl")
include("layers/normalisation.jl")
include("layers/stateless.jl")
include("layers/conv.jl")
end
@testset "CUDA" begin
if Flux.use_cuda[]
include("cuda/cuda.jl")
else
@warn "CUDA unavailable, not testing GPU support"
end
end
@static if VERSION >= v"1.4"
using Documenter
@testset "Docs" begin
if VERSION >= v"1.4"
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
end
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
doctest(Flux)
end
end # testset Flux
end