diff --git a/NEWS.md b/NEWS.md
index def86ff2..587d909f 100644
--- a/NEWS.md
+++ b/NEWS.md
@@ -1,5 +1,6 @@
 # v0.11
-* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
+* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
+* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
 * Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].
 
 # v0.10.5
diff --git a/src/data/dataloader.jl b/src/data/dataloader.jl
index 2db6f6e5..a60313a9 100644
--- a/src/data/dataloader.jl
+++ b/src/data/dataloader.jl
@@ -16,8 +16,8 @@ end
 An object that iterates over mini-batches of `data`,
 each mini-batch containing `batchsize` observations
 (except possibly the last one).
-Takes as input a data tensors or a tuple of one or more such tensors.
-The last dimension in each tensor is considered to be the observation dimension.
+Takes as input a single data tensor, or a tuple (or `NamedTuple`) of one or more such tensors.
+The last dimension in each tensor is considered to be the observation dimension.
 
 If `shuffle=true`, shuffles the observations each time iterations are re-started.
 If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
@@ -57,6 +57,13 @@ Usage example:
 
     # train for 10 epochs
     using IterTools: ncycle
    Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
+
+    # tensors can be accessed by name when a NamedTuple is passed
+    train_loader = DataLoader((images = Xtrain, labels = Ytrain), batchsize=2, shuffle=true)
+    for datum in train_loader
+        @assert size(datum.images) == (10, 2)
+        @assert size(datum.labels) == (2,)
+    end
 """
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
    batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
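
The mechanism behind the `NamedTuple` support added here can be sketched in plain Julia, without Flux: `map` over a `NamedTuple` preserves its field names, so a mini-batch can be built by selecting the same observation indices (along the last dimension) from every field. The array shapes below are hypothetical, chosen to match the docstring example; this is an illustrative sketch of the idea, not the `DataLoader` implementation itself.

```julia
# Hypothetical data: 10 features × 100 observations, plus 100 labels.
Xtrain = rand(10, 100)
Ytrain = rand(100)

data = (images = Xtrain, labels = Ytrain)
idxs = 1:2   # indices of one mini-batch of size 2

# Select the same observations from every field; the last dimension of
# each tensor is the observation dimension, so we index it with `idxs`
# and take full slices (Colon) of all leading dimensions.
# `map` over a NamedTuple returns a NamedTuple with the same names.
batch = map(x -> x[ntuple(_ -> Colon(), ndims(x) - 1)..., idxs], data)

@assert size(batch.images) == (10, 2)
@assert size(batch.labels) == (2,)
```

Because the result keeps the field names, each mini-batch yielded by the loader can be accessed as `datum.images` and `datum.labels` instead of by positional destructuring.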