From 1dbaf328100a00949cab29d65a11a0aeb7786b09 Mon Sep 17 00:00:00 2001 From: cossio Date: Fri, 12 Jun 2020 12:05:52 +0200 Subject: [PATCH] DataLoader type inference tests --- src/data/dataloader.jl | 4 ++-- test/data.jl | 5 +++++ 2 files changed, 7 insertions(+), 2 deletions(-) diff --git a/src/data/dataloader.jl b/src/data/dataloader.jl index 7d4c5fef..f0aea9ac 100644 --- a/src/data/dataloader.jl +++ b/src/data/dataloader.jl @@ -105,6 +105,6 @@ function _nobs(data::Union{Tuple, NamedTuple}) end _getobs(data::AbstractArray, i) = selectdim(data, ndims(data), i) -_getobs(data::Union{Tuple, NamedTuple}, i) = map(x -> _getobs(x, i), data) +_getobs(data::Union{Tuple, NamedTuple}, i) = map(Base.Fix2(_getobs, i), data) -Base.eltype(d::DataLoader{D}) where D = D +Base.eltype(::DataLoader{D}) where D = D diff --git a/test/data.jl b/test/data.jl index 40211f60..7f94caef 100644 --- a/test/data.jl +++ b/test/data.jl @@ -3,6 +3,7 @@ Y = [1:5;] d = DataLoader(X, batchsize=2) + @inferred first(d) batches = collect(d) @test eltype(batches) == eltype(d) == typeof(X) @test length(batches) == 3 @@ -11,6 +12,7 @@ @test batches[3] == X[:,5:5] d = DataLoader(X, batchsize=2, partial=false) + @inferred first(d) batches = collect(d) @test eltype(batches) == eltype(d) == typeof(X) @test length(batches) == 2 @@ -18,6 +20,7 @@ @test batches[2] == X[:,3:4] d = DataLoader((X,), batchsize=2, partial=false) + @inferred first(d) batches = collect(d) @test eltype(batches) == eltype(d) == Tuple{typeof(X)} @test length(batches) == 2 @@ -25,6 +28,7 @@ @test batches[2] == (X[:,3:4],) d = DataLoader((X, Y), batchsize=2) + @inferred first(d) batches = collect(d) @test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)} @test length(batches) == 3 @@ -40,6 +44,7 @@ # test with NamedTuple d = DataLoader((x=X, y=Y), batchsize=2) + @inferred first(d) batches = collect(d) @test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}} @test length(batches) == 3