extend dataloader
This commit is contained in:
parent
c444226db5
commit
d77dbc4931
|
@ -12,13 +12,14 @@ end
|
|||
|
||||
"""
|
||||
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
||||
DataLoader(data::Tuple; ...)
|
||||
|
||||
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
||||
(except possibly the last one).
|
||||
|
||||
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
|
||||
supervised learning. The last dimension in each tensor is considered to be the observation
|
||||
dimension.
|
||||
Takes as input one or more data tensors (e.g. X in unsupervised learning, X and Y in
|
||||
supervised learning), or a tuple of such tensors.
|
||||
The last dimension in each tensor is considered to be the observation dimension.
|
||||
|
||||
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
||||
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
|
||||
|
@ -35,11 +36,27 @@ Example usage:
|
|||
...
|
||||
end
|
||||
|
||||
train_loader = DataLoader(Xtrain, batchsize=2)
|
||||
# iterate over 50 mini-batches of size 2
|
||||
for x in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
...
|
||||
end
|
||||
|
||||
train_loader.data # original dataset
|
||||
|
||||
# similar, but yielding tuples
|
||||
train_loader = DataLoader((Xtrain,), batchsize=2)
|
||||
for (x,) in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
...
|
||||
end
|
||||
|
||||
Xtrain = rand(10, 100)
|
||||
Ytrain = rand(100)
|
||||
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
|
||||
# or equivalently
|
||||
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
|
||||
for epoch in 1:100
|
||||
for (x, y) in train_loader
|
||||
@assert size(x) == (10, 2)
|
||||
|
@ -52,24 +69,20 @@ Example usage:
|
|||
using IterTools: ncycle
|
||||
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
|
||||
"""
|
||||
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
|
||||
|
||||
nx = size(data[1])[end]
|
||||
for i=2:length(data)
|
||||
nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations"))
|
||||
n = _nobs(data)
|
||||
if n < batchsize
|
||||
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
|
||||
batchsize = n
|
||||
end
|
||||
if nx < batchsize
|
||||
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
|
||||
batchsize = nx
|
||||
end
|
||||
imax = partial ? nx : nx - batchsize + 1
|
||||
ids = 1:min(nx, batchsize)
|
||||
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
|
||||
imax = partial ? n : n - batchsize + 1
|
||||
ids = 1:min(n, batchsize)
|
||||
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
|
||||
end
|
||||
|
||||
# Select the observations `ids` along the last dimension of `x`, keeping every
# leading dimension whole (one `Colon()` per leading dimension).
function getdata(x::AbstractArray, ids)
    leading = ntuple(_ -> Colon(), ndims(x) - 1)
    return x[leading..., ids]
end
|
||||
# Varargs convenience method: collects the positional data arguments into a
# tuple and forwards them, with all keyword arguments, to the single-argument method.
function DataLoader(data...; kws...)
    return DataLoader(data; kws...)
end
|
||||
|
||||
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
|
||||
i >= d.imax && return nothing
|
||||
|
@ -78,11 +91,7 @@ getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
|
|||
end
|
||||
nexti = min(i + d.batchsize, d.nobs)
|
||||
ids = d.indices[i+1:nexti]
|
||||
if length(d.data) == 1
|
||||
batch = getdata(d.data[1], ids)
|
||||
else
|
||||
batch = ((getdata(x, ids) for x in d.data)...,)
|
||||
end
|
||||
batch = _getobs(d.data, ids)
|
||||
return (batch, nexti)
|
||||
end
|
||||
|
||||
|
@ -90,3 +99,17 @@ function Base.length(d::DataLoader)
|
|||
n = d.nobs / d.batchsize
|
||||
d.partial ? ceil(Int,n) : floor(Int,n)
|
||||
end
|
||||
|
||||
# Number of observations in an array: the length of its last dimension.
function _nobs(data::AbstractArray)
    return last(size(data))
end
|
||||
|
||||
# Number of observations shared by all members of the tuple `data`.
# Throws `ArgumentError` when the tuple is empty and `DimensionMismatch` when
# the members do not all report the same number of observations.
function _nobs(data::Tuple)
    isempty(data) && throw(ArgumentError("Need at least one data input"))
    expected = _nobs(data[1])
    for x in data[2:end]
        _nobs(x) == expected ||
            throw(DimensionMismatch("All data should contain same number of observations"))
    end
    return expected
end
|
||||
|
||||
# Extract the observations `i` from an array by indexing its last dimension
# with `i` and every leading dimension with `Colon()`.
function _getobs(data::AbstractArray, i)
    colons = ntuple(_ -> Colon(), ndims(data) - 1)
    return data[colons..., i]
end
|
||||
# Extract the observations `i` from every member of the tuple `data`,
# returning the results as a tuple in the same order.
function _getobs(data::Tuple, i)
    return map(x -> _getobs(x, i), data)
end
|
||||
|
|
Loading…
Reference in New Issue