remove multi-arg constructor

CarloLucibello 2020-04-29 10:31:43 +02:00
parent d77dbc4931
commit c6ba49e8ea
4 changed files with 18 additions and 21 deletions

View File

@@ -10,6 +10,7 @@ export CMUDict, cmudict
 deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
 function download_and_verify(url, path, hash)
     tmppath = tempname()
     download(url, tmppath)
@@ -51,4 +52,7 @@ export Iris
 include("housing.jl")
 export Housing
+@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)
 end
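A minimal sketch of what the new `@deprecate` line means for callers. It assumes `DataLoader` is reachable as `Flux.Data.DataLoader`, as in this tree; the arrays and sizes are made up for illustration and are not part of the commit:

    using Flux.Data: DataLoader

    X = rand(10, 100)   # 100 observations along the last dimension
    Y = rand(100)

    # Old multi-argument form: still accepted, but the deprecation forwards it
    # (emitting a depwarn when those are enabled) to the tuple form below.
    loader_old = DataLoader(X, Y, batchsize=2)

    # Preferred form after this commit: wrap the tensors in a tuple yourself.
    loader_new = DataLoader((X, Y), batchsize=2)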

View File

@@ -11,20 +11,18 @@ struct DataLoader
 end
 """
-    DataLoader(data...; batchsize=1, shuffle=false, partial=true)
-    DataLoader(data::Tuple; ...)
+    DataLoader(data; batchsize=1, shuffle=false, partial=true)
 An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
 (except possibly the last one).
-Takes as input one or more data tensors (e.g. X in unsupervised learning, X and Y in
-supervised learning,) or a tuple of such tensors.
+Takes as input a data tensor or a tuple of one or more such tensors.
 The last dimension in each tensor is considered to be the observation dimension.
 If `shuffle=true`, shuffles the observations each time iterations are re-started.
 If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
-The original data is preserved as a tuple in the `data` field of the DataLoader.
+The original data is preserved in the `data` field of the DataLoader.
 Example usage:
@@ -36,16 +34,9 @@ Example usage:
         ...
     end
-    train_loader = DataLoader(Xtrain, batchsize=2)
-    # iterate over 50 mini-batches of size 2
-    for x in train_loader
-        @assert size(x) == (10, 2)
-        ...
-    end
     train_loader.data # original dataset
-    # similar but yelding tuples
+    # similar but yielding tuples
     train_loader = DataLoader((Xtrain,), batchsize=2)
     for (x,) in train_loader
         @assert size(x) == (10, 2)
@@ -54,8 +45,6 @@ Example usage:
     Xtrain = rand(10, 100)
     Ytrain = rand(100)
-    train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
-    # or equivalently
     train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
     for epoch in 1:100
         for (x, y) in train_loader
@@ -82,8 +71,6 @@ function DataLoader(data; batchsize=1, shuffle=false, partial=true)
     DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
 end
-DataLoader(data...; kws...) = DataLoader(data; kws...)
 @propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
     i >= d.imax && return nothing
     if d.shuffle && i == 0
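To make the new single-argument API concrete, here is a short usage sketch based on the updated docstring. The array sizes mirror the docstring examples; the code itself is illustrative and not part of the commit:

    using Flux.Data: DataLoader

    Xtrain = rand(10, 100)
    Ytrain = rand(100)

    # A bare array yields plain batches, cut along the last (observation) dimension.
    for x in DataLoader(Xtrain, batchsize=2)
        @assert size(x) == (10, 2)
    end

    # A tuple of arrays yields tuples of batches; shuffle=true reshuffles the
    # observations each time iteration restarts.
    for (x, y) in DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
        @assert size(x) == (10, 2)
        @assert size(y) == (2,)
    end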

View File

@@ -15,7 +15,13 @@
     @test batches[1] == X[:,1:2]
     @test batches[2] == X[:,3:4]
-    d = DataLoader(X, Y, batchsize=2)
+    d = DataLoader((X,), batchsize=2, partial=false)
+    batches = collect(d)
+    @test length(batches) == 2
+    @test batches[1] == (X[:,1:2],)
+    @test batches[2] == (X[:,3:4],)
+    d = DataLoader((X, Y), batchsize=2)
     batches = collect(d)
     @test length(batches) == 3
     @test length(batches[1]) == 2
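The new test also exercises `partial=false`, which per the docstring drops a trailing batch smaller than `batchsize`. The test arithmetic (3 batches of size 2 by default, 2 with `partial=false`) implies `X` holds 5 observations; the concrete values below are assumed for illustration:

    using Flux.Data: DataLoader

    X = rand(2, 5)   # assumed values; what matters is the 5-observation last dimension

    length(collect(DataLoader((X,), batchsize=2)))                  # 3: keeps the single leftover observation
    length(collect(DataLoader((X,), batchsize=2, partial=false)))   # 2: the undersized trailing batch is dropped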
@@ -41,7 +47,7 @@
     X = ones(2, 10)
     Y = fill(2, 10)
     loss(x, y) = sum((y - x'*θ).^2)
-    d = DataLoader(X, Y)
+    d = DataLoader((X, Y))
     Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
     @test norm(θ .- 1) < 1e-10
 end
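Finally, a rough self-contained version of what the updated `train!` test exercises. The initialisation of θ sits above the hunk and is not visible, so `zeros(2)` below is a guess; everything else mirrors the test lines above:

    using Flux, LinearAlgebra
    using Flux.Data: DataLoader
    using IterTools: ncycle

    θ = zeros(2)                      # assumed initial value; the real one is outside the hunk
    X = ones(2, 10)
    Y = fill(2, 10)
    loss(x, y) = sum((y - x'*θ).^2)

    # The loader is now built from a tuple; train! splats each (x, y) batch into loss.
    d = DataLoader((X, Y))
    Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
    norm(θ .- 1)                      # ≈ 0, as the test asserts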