DataLoader type inference tests
This commit is contained in:
parent
cb34bb848b
commit
1dbaf32810
|
@ -105,6 +105,6 @@ function _nobs(data::Union{Tuple, NamedTuple})
|
|||
end
|
||||
|
||||
_getobs(data::AbstractArray, i) = selectdim(data, ndims(data), i)
|
||||
_getobs(data::Union{Tuple, NamedTuple}, i) = map(x -> _getobs(x, i), data)
|
||||
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
|
||||
|
||||
Base.eltype(d::DataLoader{D}) where D = D
|
||||
Base.eltype(::DataLoader{D}) where D = D
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
Y = [1:5;]
|
||||
|
||||
d = DataLoader(X, batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 3
|
||||
|
@ -11,6 +12,7 @@
|
|||
@test batches[3] == X[:,5:5]
|
||||
|
||||
d = DataLoader(X, batchsize=2, partial=false)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 2
|
||||
|
@ -18,6 +20,7 @@
|
|||
@test batches[2] == X[:,3:4]
|
||||
|
||||
d = DataLoader((X,), batchsize=2, partial=false)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
||||
@test length(batches) == 2
|
||||
|
@ -25,6 +28,7 @@
|
|||
@test batches[2] == (X[:,3:4],)
|
||||
|
||||
d = DataLoader((X, Y), batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
||||
@test length(batches) == 3
|
||||
|
@ -40,6 +44,7 @@
|
|||
|
||||
# test with NamedTuple
|
||||
d = DataLoader((x=X, y=Y), batchsize=2)
|
||||
@inferred first(d)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
||||
@test length(batches) == 3
|
||||
|
|
Loading…
Reference in New Issue