news and docs

This commit is contained in:
cossio 2020-06-12 02:09:17 +02:00
parent 02ee6ba426
commit 909a55ac10
2 changed files with 11 additions and 3 deletions

View File

@ -1,5 +1,6 @@
# v0.11 # v0.11
* Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152] * Change to `DataLoader`'s constructor [https://github.com/FluxML/Flux.jl/pull/1152]
* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218]. * Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].
# v0.10.5 # v0.10.5

View File

@ -16,8 +16,8 @@ end
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
(except possibly the last one). (except possibly the last one).
Takes as input a data tensors or a tuple of one or more such tensors. Takes as input a data tensors or a tuple (or `NamedTuple`) of one or more such tensors.
The last dimension in each tensor is considered to be the observation dimension. The last dimension in each tensor is considered to be the observation dimension.
If `shuffle=true`, shuffles the observations each time iterations are re-started. If `shuffle=true`, shuffles the observations each time iterations are re-started.
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize. If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
@ -57,6 +57,13 @@ Usage example:
# train for 10 epochs # train for 10 epochs
using IterTools: ncycle using IterTools: ncycle
Flux.train!(loss, ps, ncycle(train_loader, 10), opt) Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
# can use NamedTuple to name tensors
train_loader = DataLoader((images = Xtrain, labels = Ytrain), batchsize=2, shuffle=true)
for datum in train_loader
@assert size(datum.images) == (10, 2)
@assert size(datum.labels) == (2,)
end
""" """
function DataLoader(data; batchsize=1, shuffle=false, partial=true) function DataLoader(data; batchsize=1, shuffle=false, partial=true)
batchsize > 0 || throw(ArgumentError("Need positive batchsize")) batchsize > 0 || throw(ArgumentError("Need positive batchsize"))