Apply suggestions from code review

accept suggested changes

Co-authored-by: Carlo Lucibello <carlo.lucibello@gmail.com>
This commit is contained in:
cossio 2020-06-12 11:45:42 +02:00 committed by cossio
parent 909a55ac10
commit 75692161a7
2 changed files with 5 additions and 5 deletions

View File

@ -16,7 +16,7 @@ end
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one). (except possibly the last one).
Takes as input a data tensors or a tuple (or `NamedTuple`) of one or more such tensors. Takes as input a single data tensor, or a tuple (or a named tuple) of tensors.
The last dimension in each tensor is considered to be the observation dimension. The last dimension in each tensor is considered to be the observation dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started. If `shuffle=true`, shuffles the observations each time iterations are re-started.
@ -59,7 +59,7 @@ Usage example:
Flux.train!(loss, ps, ncycle(train_loader, 10), opt) Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
# can use NamedTuple to name tensors # can use NamedTuple to name tensors
train_loader = DataLoader((images = Xtrain, labels = Ytrain), batchsize=2, shuffle=true) train_loader = DataLoader((images=Xtrain, labels=Ytrain), batchsize=2, shuffle=true)
for datum in train_loader for datum in train_loader
@assert size(datum.images) == (10, 2) @assert size(datum.images) == (10, 2)
@assert size(datum.labels) == (2,) @assert size(datum.labels) == (2,)
@ -95,7 +95,7 @@ end
_nobs(data::AbstractArray) = size(data)[end] _nobs(data::AbstractArray) = size(data)[end]
function _nobs(data::Union{Tuple,NamedTuple}) function _nobs(data::Union{Tuple, NamedTuple})
length(data) > 0 || throw(ArgumentError("Need at least one data input")) length(data) > 0 || throw(ArgumentError("Need at least one data input"))
n = _nobs(data[1]) n = _nobs(data[1])
if !all(x -> _nobs(x) == n, Base.tail(data)) if !all(x -> _nobs(x) == n, Base.tail(data))
@ -108,6 +108,6 @@ function _getobs(data::AbstractArray{T,N}, i) where {T,N}
getindex(data, ntuple(i->Colon(), Val(N-1))..., i) getindex(data, ntuple(i->Colon(), Val(N-1))..., i)
end end
_getobs(data::Union{Tuple,NamedTuple}, i) = map(x -> _getobs(x, i), data) _getobs(data::Union{Tuple, NamedTuple}, i) = map(x -> _getobs(x, i), data)
Base.eltype(d::DataLoader{D}) where D = D Base.eltype(d::DataLoader{D}) where D = D

View File

@ -39,7 +39,7 @@
@test batches[3][2] == Y[5:5] @test batches[3][2] == Y[5:5]
# test with NamedTuple # test with NamedTuple
d = DataLoader((x = X, y = Y), batchsize=2) d = DataLoader((x=X, y=Y), batchsize=2)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}} @test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
@test length(batches) == 3 @test length(batches) == 3