extend dataloader
This commit is contained in:
parent
792a1c54f8
commit
a643cb6758
|
@ -384,4 +384,4 @@ version = "0.4.20"
|
||||||
deps = ["MacroTools"]
|
deps = ["MacroTools"]
|
||||||
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
|
git-tree-sha1 = "b3b4882cc9accf6731a08cc39543fbc6b669dca8"
|
||||||
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
|
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
|
||||||
version = "0.2.0"
|
version = "0.2.0"
|
|
@ -1,6 +1,6 @@
|
||||||
name = "Flux"
|
name = "Flux"
|
||||||
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
uuid = "587475ba-b771-5e3f-ad9e-33799f191a9c"
|
||||||
version = "0.10.5"
|
version = "0.11.0"
|
||||||
|
|
||||||
[deps]
|
[deps]
|
||||||
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
|
||||||
|
|
|
@ -51,4 +51,6 @@ export Iris
|
||||||
include("housing.jl")
|
include("housing.jl")
|
||||||
export Housing
|
export Housing
|
||||||
|
|
||||||
|
@deprecate DataLoader(x...; kws...) DataLoader(x; kws...)
|
||||||
|
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
|
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
|
||||||
|
|
||||||
struct DataLoader
|
struct DataLoader{D}
|
||||||
data
|
data::D
|
||||||
batchsize::Int
|
batchsize::Int
|
||||||
nobs::Int
|
nobs::Int
|
||||||
partial::Bool
|
partial::Bool
|
||||||
|
@ -11,21 +11,20 @@ struct DataLoader
|
||||||
end
|
end
|
||||||
|
|
||||||
"""
|
"""
|
||||||
DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||||
|
|
||||||
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
An object that iterates over mini-batches of `data`, each mini-batch containing `batchsize` observations
|
||||||
(except possibly the last one).
|
(except possibly the last one).
|
||||||
|
|
||||||
Takes as input one or more data tensors, e.g. X in unsupervised learning, X and Y in
|
Takes as input a data tensor or a tuple of one or more such tensors.
|
||||||
supervised learning. The last dimension in each tensor is considered to be the observation
|
The last dimension in each tensor is considered to be the observation dimension.
|
||||||
dimension.
|
|
||||||
|
|
||||||
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
If `shuffle=true`, shuffles the observations each time iterations are re-started.
|
||||||
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
|
If `partial=false`, drops the last mini-batch if it is smaller than the batchsize.
|
||||||
|
|
||||||
The original data is preserved as a tuple in the `data` field of the DataLoader.
|
The original data is preserved in the `data` field of the DataLoader.
|
||||||
|
|
||||||
Example usage:
|
Usage example:
|
||||||
|
|
||||||
Xtrain = rand(10, 100)
|
Xtrain = rand(10, 100)
|
||||||
train_loader = DataLoader(Xtrain, batchsize=2)
|
train_loader = DataLoader(Xtrain, batchsize=2)
|
||||||
|
@ -37,9 +36,16 @@ Example usage:
|
||||||
|
|
||||||
train_loader.data # original dataset
|
train_loader.data # original dataset
|
||||||
|
|
||||||
|
# similar, but yielding tuples
|
||||||
|
train_loader = DataLoader((Xtrain,), batchsize=2)
|
||||||
|
for (x,) in train_loader
|
||||||
|
@assert size(x) == (10, 2)
|
||||||
|
...
|
||||||
|
end
|
||||||
|
|
||||||
Xtrain = rand(10, 100)
|
Xtrain = rand(10, 100)
|
||||||
Ytrain = rand(100)
|
Ytrain = rand(100)
|
||||||
train_loader = DataLoader(Xtrain, Ytrain, batchsize=2, shuffle=true)
|
train_loader = DataLoader((Xtrain, Ytrain), batchsize=2, shuffle=true)
|
||||||
for epoch in 1:100
|
for epoch in 1:100
|
||||||
for (x, y) in train_loader
|
for (x, y) in train_loader
|
||||||
@assert size(x) == (10, 2)
|
@assert size(x) == (10, 2)
|
||||||
|
@ -52,25 +58,18 @@ Example usage:
|
||||||
using IterTools: ncycle
|
using IterTools: ncycle
|
||||||
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
|
Flux.train!(loss, ps, ncycle(train_loader, 10), opt)
|
||||||
"""
|
"""
|
||||||
function DataLoader(data...; batchsize=1, shuffle=false, partial=true)
|
function DataLoader(data; batchsize=1, shuffle=false, partial=true)
|
||||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
|
||||||
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
|
batchsize > 0 || throw(ArgumentError("Need positive batchsize"))
|
||||||
|
|
||||||
nx = size(data[1])[end]
|
n = _nobs(data)
|
||||||
for i=2:length(data)
|
if n < batchsize
|
||||||
nx != size(data[i])[end] && throw(DimensionMismatch("All data should contain same number of observations"))
|
@warn "Number of observations less than batchsize, decreasing the batchsize to $n"
|
||||||
|
batchsize = n
|
||||||
end
|
end
|
||||||
if nx < batchsize
|
imax = partial ? n : n - batchsize + 1
|
||||||
@warn "Number of data points less than batchsize, decreasing the batchsize to $nx"
|
DataLoader(data, batchsize, n, partial, imax, [1:n;], shuffle)
|
||||||
batchsize = nx
|
|
||||||
end
|
|
||||||
imax = partial ? nx : nx - batchsize + 1
|
|
||||||
ids = 1:min(nx, batchsize)
|
|
||||||
DataLoader(data, batchsize, nx, partial, imax, [1:nx;], shuffle)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
|
|
||||||
|
|
||||||
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
|
@propagate_inbounds function Base.iterate(d::DataLoader, i=0) # returns data in d.indices[i+1:i+batchsize]
|
||||||
i >= d.imax && return nothing
|
i >= d.imax && return nothing
|
||||||
if d.shuffle && i == 0
|
if d.shuffle && i == 0
|
||||||
|
@ -78,11 +77,7 @@ getdata(x::AbstractArray, ids) = x[(Base.Colon() for _=1:ndims(x)-1)..., ids]
|
||||||
end
|
end
|
||||||
nexti = min(i + d.batchsize, d.nobs)
|
nexti = min(i + d.batchsize, d.nobs)
|
||||||
ids = d.indices[i+1:nexti]
|
ids = d.indices[i+1:nexti]
|
||||||
if length(d.data) == 1
|
batch = _getobs(d.data, ids)
|
||||||
batch = getdata(d.data[1], ids)
|
|
||||||
else
|
|
||||||
batch = ((getdata(x, ids) for x in d.data)...,)
|
|
||||||
end
|
|
||||||
return (batch, nexti)
|
return (batch, nexti)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -90,3 +85,22 @@ function Base.length(d::DataLoader)
|
||||||
n = d.nobs / d.batchsize
|
n = d.nobs / d.batchsize
|
||||||
d.partial ? ceil(Int,n) : floor(Int,n)
|
d.partial ? ceil(Int,n) : floor(Int,n)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Number of observations in an array: the extent of its last dimension.
_nobs(data::AbstractArray) = last(size(data))
|
||||||
|
|
||||||
|
# Number of observations shared by every element of the tuple `data`.
#
# Throws `ArgumentError` when the tuple is empty and `DimensionMismatch` when
# the elements do not all report the same number of observations.
function _nobs(data::Tuple)
    isempty(data) && throw(ArgumentError("Need at least one data input"))
    n = _nobs(first(data))
    # Compare the remaining elements against the first; `Base.tail` drops the
    # first element of the tuple.
    if !all(x -> _nobs(x) == n, Base.tail(data))
        throw(DimensionMismatch("All data should contain same number of observations"))
    end
    return n
end
|
||||||
|
|
||||||
|
# Select the observations `i` along the last dimension of `data`, keeping every
# leading dimension whole (e.g. `data[:, :, i]` for a 3-d array).
function _getobs(data::AbstractArray{T,N}, i) where {T,N}
    # One `:` per leading dimension; only the last dimension is indexed by `i`.
    leading = ntuple(_ -> Colon(), N - 1)
    return getindex(data, leading..., i)
end
|
||||||
|
|
||||||
|
# Apply `_getobs` with the same indices `i` to every element of the tuple,
# returning the results as a tuple.
function _getobs(data::Tuple, i)
    return map(data) do x
        _getobs(x, i)
    end
end
|
||||||
|
|
||||||
|
# Element type declared for iteration over a `DataLoader{D}`: `D`, the type of the
# wrapped `data` field.
# Defined on the type, as the iteration interface expects, so that both `eltype(d)`
# and `eltype(typeof(d))` return `D`; the instance form reaches this method through
# Base's `eltype(x) = eltype(typeof(x))` fallback.
Base.eltype(::Type{DataLoader{D}}) where {D} = D
|
||||||
|
|
|
@ -121,4 +121,4 @@ macro epochs(n, ex)
|
||||||
@info "Epoch $i"
|
@info "Epoch $i"
|
||||||
$(esc(ex))
|
$(esc(ex))
|
||||||
end)
|
end)
|
||||||
end
|
end
|
14
test/data.jl
14
test/data.jl
|
@ -4,6 +4,7 @@
|
||||||
|
|
||||||
d = DataLoader(X, batchsize=2)
|
d = DataLoader(X, batchsize=2)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == typeof(X)
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
@test batches[1] == X[:,1:2]
|
@test batches[1] == X[:,1:2]
|
||||||
@test batches[2] == X[:,3:4]
|
@test batches[2] == X[:,3:4]
|
||||||
|
@ -11,12 +12,21 @@
|
||||||
|
|
||||||
d = DataLoader(X, batchsize=2, partial=false)
|
d = DataLoader(X, batchsize=2, partial=false)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == typeof(X)
|
||||||
@test length(batches) == 2
|
@test length(batches) == 2
|
||||||
@test batches[1] == X[:,1:2]
|
@test batches[1] == X[:,1:2]
|
||||||
@test batches[2] == X[:,3:4]
|
@test batches[2] == X[:,3:4]
|
||||||
|
|
||||||
d = DataLoader(X, Y, batchsize=2)
|
d = DataLoader((X,), batchsize=2, partial=false)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
||||||
|
@test length(batches) == 2
|
||||||
|
@test batches[1] == (X[:,1:2],)
|
||||||
|
@test batches[2] == (X[:,3:4],)
|
||||||
|
|
||||||
|
d = DataLoader((X, Y), batchsize=2)
|
||||||
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
@test length(batches[1]) == 2
|
@test length(batches[1]) == 2
|
||||||
@test length(batches[2]) == 2
|
@test length(batches[2]) == 2
|
||||||
|
@ -41,7 +51,7 @@
|
||||||
X = ones(2, 10)
|
X = ones(2, 10)
|
||||||
Y = fill(2, 10)
|
Y = fill(2, 10)
|
||||||
loss(x, y) = sum((y - x'*θ).^2)
|
loss(x, y) = sum((y - x'*θ).^2)
|
||||||
d = DataLoader(X, Y)
|
d = DataLoader((X, Y))
|
||||||
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
|
Flux.train!(loss, [θ], ncycle(d, 10), Descent(0.1))
|
||||||
@test norm(θ .- 1) < 1e-10
|
@test norm(θ .- 1) < 1e-10
|
||||||
end
|
end
|
||||||
|
|
|
@ -2,49 +2,45 @@ using Flux
|
||||||
using Flux.Data
|
using Flux.Data
|
||||||
using Test
|
using Test
|
||||||
using Random, Statistics, LinearAlgebra
|
using Random, Statistics, LinearAlgebra
|
||||||
using Documenter
|
|
||||||
using IterTools: ncycle
|
using IterTools: ncycle
|
||||||
|
|
||||||
Random.seed!(0)
|
Random.seed!(0)
|
||||||
|
|
||||||
@testset "Flux" begin
|
@testset "Utils" begin
|
||||||
|
include("utils.jl")
|
||||||
|
end
|
||||||
|
|
||||||
@testset "Utils" begin
|
@testset "Onehot" begin
|
||||||
include("utils.jl")
|
include("onehot.jl")
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Onehot" begin
|
@testset "Optimise" begin
|
||||||
include("onehot.jl")
|
include("optimise.jl")
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Optimise" begin
|
@testset "Data" begin
|
||||||
include("optimise.jl")
|
include("data.jl")
|
||||||
end
|
end
|
||||||
|
|
||||||
@testset "Data" begin
|
@testset "Layers" begin
|
||||||
include("data.jl")
|
include("layers/basic.jl")
|
||||||
end
|
include("layers/normalisation.jl")
|
||||||
|
include("layers/stateless.jl")
|
||||||
@testset "Layers" begin
|
include("layers/conv.jl")
|
||||||
include("layers/basic.jl")
|
end
|
||||||
include("layers/normalisation.jl")
|
|
||||||
include("layers/stateless.jl")
|
@testset "CUDA" begin
|
||||||
include("layers/conv.jl")
|
if Flux.use_cuda[]
|
||||||
end
|
include("cuda/cuda.jl")
|
||||||
|
else
|
||||||
@testset "CUDA" begin
|
@warn "CUDA unavailable, not testing GPU support"
|
||||||
if Flux.use_cuda[]
|
|
||||||
include("cuda/cuda.jl")
|
|
||||||
else
|
|
||||||
@warn "CUDA unavailable, not testing GPU support"
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@static if VERSION >= v"1.4"
|
||||||
|
using Documenter
|
||||||
@testset "Docs" begin
|
@testset "Docs" begin
|
||||||
if VERSION >= v"1.4"
|
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
doctest(Flux)
|
||||||
doctest(Flux)
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
end
|
||||||
end # testset Flux
|
|
||||||
|
|
Loading…
Reference in New Issue