DataLoader with NamedTuple

This commit is contained in:
cossio 2020-06-12 01:57:37 +02:00
parent 97406507fd
commit 02ee6ba426
2 changed files with 19 additions and 4 deletions

View File

@ -88,19 +88,19 @@ end
_nobs(data::AbstractArray) = size(data)[end]
function _nobs(data::Tuple)
function _nobs(data::Union{Tuple,NamedTuple})
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
n = _nobs(data[1])
if !all(x -> _nobs(x) == n, data[2:end])
if !all(x -> _nobs(x) == n, Base.tail(data))
throw(DimensionMismatch("All data should contain same number of observations"))
end
return n
end
function _getobs(data::AbstractArray{T,N}, i) where {T,N}
getindex(data, ntuple(i->Colon(), N-1)..., i)
getindex(data, ntuple(i->Colon(), Val(N-1))..., i)
end
_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
_getobs(data::Union{Tuple,NamedTuple}, i) = map(x -> _getobs(x, i), data)
Base.eltype(d::DataLoader{D}) where D = D

View File

@ -38,6 +38,21 @@
@test batches[3][1] == X[:,5:5]
@test batches[3][2] == Y[5:5]
# test with NamedTuple
d = DataLoader((x = X, y = Y), batchsize=2)
batches = collect(d)
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
@test length(batches) == 3
@test length(batches[1]) == 2
@test length(batches[2]) == 2
@test length(batches[3]) == 2
@test batches[1][1] == batches[1].x == X[:,1:2]
@test batches[1][2] == batches[1].y == Y[1:2]
@test batches[2][1] == batches[2].x == X[:,3:4]
@test batches[2][2] == batches[2].y == Y[3:4]
@test batches[3][1] == batches[3].x == X[:,5:5]
@test batches[3][2] == batches[3].y == Y[5:5]
# test interaction with `train!`
θ = ones(2)
X = zeros(2, 10)