DataLoader type inference tests
This commit is contained in:
parent
cb34bb848b
commit
1dbaf32810
|
@ -105,6 +105,6 @@ function _nobs(data::Union{Tuple, NamedTuple})
|
||||||
end
|
end
|
||||||
|
|
||||||
_getobs(data::AbstractArray, i) = selectdim(data, ndims(data), i)
|
_getobs(data::AbstractArray, i) = selectdim(data, ndims(data), i)
|
||||||
_getobs(data::Union{Tuple, NamedTuple}, i) = map(x -> _getobs(x, i), data)
|
_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
|
||||||
|
|
||||||
Base.eltype(d::DataLoader{D}) where D = D
|
Base.eltype(::DataLoader{D}) where D = D
|
||||||
|
|
|
@ -3,6 +3,7 @@
|
||||||
Y = [1:5;]
|
Y = [1:5;]
|
||||||
|
|
||||||
d = DataLoader(X, batchsize=2)
|
d = DataLoader(X, batchsize=2)
|
||||||
|
@inferred first(d)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
@test eltype(batches) == eltype(d) == typeof(X)
|
@test eltype(batches) == eltype(d) == typeof(X)
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
|
@ -11,6 +12,7 @@
|
||||||
@test batches[3] == X[:,5:5]
|
@test batches[3] == X[:,5:5]
|
||||||
|
|
||||||
d = DataLoader(X, batchsize=2, partial=false)
|
d = DataLoader(X, batchsize=2, partial=false)
|
||||||
|
@inferred first(d)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
@test eltype(batches) == eltype(d) == typeof(X)
|
@test eltype(batches) == eltype(d) == typeof(X)
|
||||||
@test length(batches) == 2
|
@test length(batches) == 2
|
||||||
|
@ -18,6 +20,7 @@
|
||||||
@test batches[2] == X[:,3:4]
|
@test batches[2] == X[:,3:4]
|
||||||
|
|
||||||
d = DataLoader((X,), batchsize=2, partial=false)
|
d = DataLoader((X,), batchsize=2, partial=false)
|
||||||
|
@inferred first(d)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
||||||
@test length(batches) == 2
|
@test length(batches) == 2
|
||||||
|
@ -25,6 +28,7 @@
|
||||||
@test batches[2] == (X[:,3:4],)
|
@test batches[2] == (X[:,3:4],)
|
||||||
|
|
||||||
d = DataLoader((X, Y), batchsize=2)
|
d = DataLoader((X, Y), batchsize=2)
|
||||||
|
@inferred first(d)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
|
@ -40,6 +44,7 @@
|
||||||
|
|
||||||
# test with NamedTuple
|
# test with NamedTuple
|
||||||
d = DataLoader((x=X, y=Y), batchsize=2)
|
d = DataLoader((x=X, y=Y), batchsize=2)
|
||||||
|
@inferred first(d)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
|
|
Loading…
Reference in New Issue