DataLoader type inference tests

This commit is contained in:
cossio 2020-06-12 12:05:52 +02:00
parent cb34bb848b
commit 1dbaf32810
2 changed files with 7 additions and 2 deletions

View File

@ -105,6 +105,6 @@ function _nobs(data::Union{Tuple, NamedTuple})
end end
_getobs(data::AbstractArray, i) = selectdim(data, ndims(data), i) _getobs(data::AbstractArray, i) = selectdim(data, ndims(data), i)
_getobs(data::Union{Tuple, NamedTuple}, i) = map(x -> _getobs(x, i), data) _getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data)
Base.eltype(d::DataLoader{D}) where D = D Base.eltype(::DataLoader{D}) where D = D

View File

@ -3,6 +3,7 @@
Y = [1:5;] Y = [1:5;]
d = DataLoader(X, batchsize=2) d = DataLoader(X, batchsize=2)
@inferred first(d)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X) @test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 3 @test length(batches) == 3
@ -11,6 +12,7 @@
@test batches[3] == X[:,5:5] @test batches[3] == X[:,5:5]
d = DataLoader(X, batchsize=2, partial=false) d = DataLoader(X, batchsize=2, partial=false)
@inferred first(d)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == typeof(X) @test eltype(batches) == eltype(d) == typeof(X)
@test length(batches) == 2 @test length(batches) == 2
@ -18,6 +20,7 @@
@test batches[2] == X[:,3:4] @test batches[2] == X[:,3:4]
d = DataLoader((X,), batchsize=2, partial=false) d = DataLoader((X,), batchsize=2, partial=false)
@inferred first(d)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X)} @test eltype(batches) == eltype(d) == Tuple{typeof(X)}
@test length(batches) == 2 @test length(batches) == 2
@ -25,6 +28,7 @@
@test batches[2] == (X[:,3:4],) @test batches[2] == (X[:,3:4],)
d = DataLoader((X, Y), batchsize=2) d = DataLoader((X, Y), batchsize=2)
@inferred first(d)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)} @test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
@test length(batches) == 3 @test length(batches) == 3
@ -40,6 +44,7 @@
# test with NamedTuple # test with NamedTuple
d = DataLoader((x=X, y=Y), batchsize=2) d = DataLoader((x=X, y=Y), batchsize=2)
@inferred first(d)
batches = collect(d) batches = collect(d)
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}} @test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
@test length(batches) == 3 @test length(batches) == 3