remove multi-arg constructor
This commit is contained in:
parent
d77dbc4931
commit
c6ba49e8ea
@ -10,6 +10,7 @@ export CMUDict, cmudict
|
|||||||
|
|
||||||
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
|
deps(path...) = joinpath(@__DIR__, "..", "..", "deps", path...)
|
||||||
|
|
||||||
|
|
||||||
function download_and_verify(url, path, hash)
|
function download_and_verify(url, path, hash)
|
||||||
tmppath = tempname()
|
tmppath = tempname()
|
||||||
download(url, tmppath)
|
download(url, tmppath)
|
||||||
@ -51,4 +52,7 @@ export Iris
|
|||||||
include("housing.jl")
|
include("housing.jl")
|
||||||
export Housing
|
export Housing
|
||||||
|
|
||||||
|
|
||||||
|
@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
@ -11,20 +11,18 @@ struct DataLoader
|
|||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||||
DataLoader(data::Tuple; ...)
|
|
||||||
|
|
||||||
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
||||||
(except possibly the last one).
|
(except possibly the last one).
|
||||||
|
|
||||||
Takes as input one or more data tensors (e.g. X in unsupervised learning, X and Y in
|
Takes as input a data tensor or a tuple of one or more such tensors.
|
||||||
supervised learning,) or a tuple of such tensors.
|
|
||||||
The last dimension in each tensor is considered to be the observation dimension.
|
The last dimension in each tensor is considered to be the observation dimension.
|
||||||
|
|
||||||
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
||||||
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
|
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
|
||||||
|
|
||||||
The original data is preserved as a tuple in the `data` field of the DataLoader.
|
The original data is preserved in the `data` field of the DataLoader.
|
||||||
|
|
||||||
Example usage:
|
Example usage:
|
||||||
|
|
||||||
@ -36,16 +34,9 @@ Example usage:
|
|||||||
...
|
...
|
||||||
end
|
end
|
||||||
|
|
||||||
train_loader = DataLoader(Xtrain, batchsize=2)
|
|
||||||
# iterate over 50 mini-batches of size 2
|
|
||||||
for x in train_loader
|
|
||||||
@assert size(x) == (10, 2)
|
|
||||||
...
|
|
||||||
end
|
|
||||||
|
|
||||||
train_loader.data # original dataset
|
train_loader.data # original dataset
|
||||||
|
|
||||||
# similar but yelding tuples
|
# similar but yielding tuples
|
||||||
train_loader = DataLoader((Xtrain,), batchsize=2)
|
train_loader = DataLoader((Xtrain,), batchsize=2)
|
||||||
for (x,) in train_loader
|
for (x,) in train_loader
|
||||||
@assert size(x) == (10, 2)
|
@assert size(x) == (10, 2)
|
||||||
@ -54,8 +45,6 @@ Example usage:
|
|||||||
|
|
||||||
Xtrain = rand(10, 100)
|
Xtrain = rand(10, 100)
|
||||||
Ytrain = rand(100)
|
Ytrain = rand(100)
|
||||||
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
|
|
||||||
# or equivalently
|
|
||||||
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
|
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
|
||||||
for epoch in 1:100
|
for epoch in 1:100
|
||||||
for (x, y) in train_loader
|
for (x, y) in train_loader
|
||||||
@ -82,8 +71,6 @@ function DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
|||||||
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
|
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
|
||||||
end
|
end
|
||||||
|
|
||||||
DataLoader(data...; kws...) = DataLoader(data; kws...)
|
|
||||||
|
|
||||||
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
|
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
|
||||||
i >= d.imax && return nothing
|
i >= d.imax && return nothing
|
||||||
if d.shuffle && i == 0
|
if d.shuffle && i == 0
|
||||||
|
10
test/data.jl
10
test/data.jl
@ -15,7 +15,13 @@
|
|||||||
@test batches[1] == X[:,1:2]
|
@test batches[1] == X[:,1:2]
|
||||||
@test batches[2] == X[:,3:4]
|
@test batches[2] == X[:,3:4]
|
||||||
|
|
||||||
d = DataLoader(X, Y, batchsize=2)
|
d = DataLoader((X,), batchsize=2, partial=false)
|
||||||
|
batches = collect(d)
|
||||||
|
@test length(batches) == 2
|
||||||
|
@test batches[1] == (X[:,1:2],)
|
||||||
|
@test batches[2] == (X[:,3:4],)
|
||||||
|
|
||||||
|
d = DataLoader((X, Y), batchsize=2)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
@test length(batches[1]) == 2
|
@test length(batches[1]) == 2
|
||||||
@ -41,7 +47,7 @@
|
|||||||
X = ones(2, 10)
|
X = ones(2, 10)
|
||||||
Y = fill(2, 10)
|
Y = fill(2, 10)
|
||||||
loss(x, y) = sum((y - x'*θ).^2)
|
loss(x, y) = sum((y - x'*θ).^2)
|
||||||
d = DataLoader(X, Y)
|
d = DataLoader((X, Y))
|
||||||
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
|
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
|
||||||
@test norm(θ .- 1) < 1e-10
|
@test norm(θ .- 1) < 1e-10
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user