DataLoader with NamedTuple
This commit is contained in:
parent
97406507fd
commit
02ee6ba426
|
@ -88,19 +88,19 @@ end
|
|||
|
||||
_nobs(data::AbstractArray) = size(data)[end]
|
||||
|
||||
function _nobs(data::Tuple)
|
||||
function _nobs(data::Union{Tuple,NamedTuple})
|
||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||
n = _nobs(data[1])
|
||||
if !all(x -> _nobs(x) == n, data[2:end])
|
||||
if !all(x -> _nobs(x) == n, Base.tail(data))
|
||||
throw(DimensionMismatch("All data should contain same number of observations"))
|
||||
end
|
||||
return n
|
||||
end
|
||||
|
||||
function _getobs(data::AbstractArray{T,N}, i) where {T,N}
|
||||
getindex(data, ntuple(i->Colon(), N-1)..., i)
|
||||
getindex(data, ntuple(i->Colon(), Val(N-1))..., i)
|
||||
end
|
||||
|
||||
_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
|
||||
_getobs(data::Union{Tuple,NamedTuple}, i) = map(x -> _getobs(x, i), data)
|
||||
|
||||
Base.eltype(d::DataLoader{D}) where D = D
|
||||
|
|
15
test/data.jl
15
test/data.jl
|
@ -38,6 +38,21 @@
|
|||
@test batches[3][1] == X[:,5:5]
|
||||
@test batches[3][2] == Y[5:5]
|
||||
|
||||
# test with NamedTuple
|
||||
d = DataLoader((x = X, y = Y), batchsize=2)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
||||
@test length(batches) == 3
|
||||
@test length(batches[1]) == 2
|
||||
@test length(batches[2]) == 2
|
||||
@test length(batches[3]) == 2
|
||||
@test batches[1][1] == batches[1].x == X[:,1:2]
|
||||
@test batches[1][2] == batches[1].y == Y[1:2]
|
||||
@test batches[2][1] == batches[2].x == X[:,3:4]
|
||||
@test batches[2][2] == batches[2].y == Y[3:4]
|
||||
@test batches[3][1] == batches[3].x == X[:,5:5]
|
||||
@test batches[3][2] == batches[3].y == Y[5:5]
|
||||
|
||||
# test interaction with `train!`
|
||||
θ = ones(2)
|
||||
X = zeros(2, 10)
|
||||
|
|
Loading…
Reference in New Issue