DataLoader type inference tests

This commit is contained in:
cossio 2020-06-12 12:05:52 +02:00
parent cb34bb848b
commit 1dbaf32810
2 changed files with 7 additions and 2 deletions

View File

@ -105,6 +105,6 @@ function _nobs(data::Union{Tuple, NamedTuple})
end
_getobs(data::AbstractArray, i) = selectdim(data, ndims(data), i)
_getobs(data::Union{Tuple, NamedTuple}, i) = map(x -> _getobs(x, i), data)
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
Base.eltype(d::DataLoader{D}) where D = D
Base.eltype(::DataLoader{D}) where D = D

View File

@ -3,6 +3,7 @@
Y = [1:5;]
d = DataLoader(X, batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 3
@ -11,6 +12,7 @@
@test batches[3] == X[:,5:5]
d = DataLoader(X, batchsize=2, partial=false)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 2
@ -18,6 +20,7 @@
@test batches[2] == X[:,3:4]
d = DataLoader((X,), batchsize=2, partial=false)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
@test length(batches) == 2
@ -25,6 +28,7 @@
@test batches[2] == (X[:,3:4],)
d = DataLoader((X, Y), batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
@test length(batches) == 3
@ -40,6 +44,7 @@
# test with NamedTuple
d = DataLoader((x=X, y=Y), batchsize=2)
@inferred first(d)
batches = collect(d)
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
@test length(batches) == 3