Apply suggestions from code review
accept suggested changes Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
This commit is contained in:
parent
909a55ac10
commit
75692161a7
|
@ -16,7 +16,7 @@ end
|
||||||
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
||||||
(except possibly the last one).
|
(except possibly the last one).
|
||||||
|
|
||||||
Takes as input a data tensors or a tuple (or `NamedTuple`) of one or more such tensors.
|
Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
|
||||||
The last dimension in each tensor is considered to be the observation dimension.
|
The last dimension in each tensor is considered to be the observation dimension.
|
||||||
|
|
||||||
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
||||||
|
@ -59,7 +59,7 @@ Usage example:
|
||||||
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
|
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
|
||||||
|
|
||||||
# can use NamedTuple to name tensors
|
# can use NamedTuple to name tensors
|
||||||
train_loader = DataLoader((images = Xtrain, labels = Ytrain), batchsize=2, shuffle=true)
|
train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
|
||||||
for datum in train_loader
|
for datum in train_loader
|
||||||
@assert size(datum.images) == (10, 2)
|
@assert size(datum.images) == (10, 2)
|
||||||
@assert size(datum.labels) == (2,)
|
@assert size(datum.labels) == (2,)
|
||||||
|
@ -95,7 +95,7 @@ end
|
||||||
|
|
||||||
_nobs(data::AbstractArray) = size(data)[end]
|
_nobs(data::AbstractArray) = size(data)[end]
|
||||||
|
|
||||||
function _nobs(data::Union{Tuple,NamedTuple})
|
function _nobs(data::Union{Tuple, NamedTuple})
|
||||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||||
n = _nobs(data[1])
|
n = _nobs(data[1])
|
||||||
if !all(x -> _nobs(x) == n, Base.tail(data))
|
if !all(x -> _nobs(x) == n, Base.tail(data))
|
||||||
|
@ -108,6 +108,6 @@ function _getobs(data::AbstractArray{T,N}, i) where {T,N}
|
||||||
getindex(data, ntuple(i->Colon(), Val(N-1))..., i)
|
getindex(data, ntuple(i->Colon(), Val(N-1))..., i)
|
||||||
end
|
end
|
||||||
|
|
||||||
_getobs(data::Union{Tuple,NamedTuple}, i) = map(x -> _getobs(x, i), data)
|
_getobs(data::Union{Tuple, NamedTuple}, i) = map(x -> _getobs(x, i), data)
|
||||||
|
|
||||||
Base.eltype(d::DataLoader{D}) where D = D
|
Base.eltype(d::DataLoader{D}) where D = D
|
||||||
|
|
|
@ -39,7 +39,7 @@
|
||||||
@test batches[3][2] == Y[5:5]
|
@test batches[3][2] == Y[5:5]
|
||||||
|
|
||||||
# test with NamedTuple
|
# test with NamedTuple
|
||||||
d = DataLoader((x = X, y = Y), batchsize=2)
|
d = DataLoader((x=X, y=Y), batchsize=2)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
|
|
Loading…
Reference in New Issue