DataLoader with NamedTuple
This commit is contained in:
parent
97406507fd
commit
02ee6ba426
|
@ -88,19 +88,19 @@ end
|
||||||
|
|
||||||
_nobs(data::AbstractArray) = size(data)[end]
|
_nobs(data::AbstractArray) = size(data)[end]
|
||||||
|
|
||||||
function _nobs(data::Tuple)
|
function _nobs(data::Union{Tuple,NamedTuple})
|
||||||
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
length(data) > 0 || throw(ArgumentError("Need at least one data input"))
|
||||||
n = _nobs(data[1])
|
n = _nobs(data[1])
|
||||||
if !all(x -> _nobs(x) == n, data[2:end])
|
if !all(x -> _nobs(x) == n, Base.tail(data))
|
||||||
throw(DimensionMismatch("All data should contain same number of observations"))
|
throw(DimensionMismatch("All data should contain same number of observations"))
|
||||||
end
|
end
|
||||||
return n
|
return n
|
||||||
end
|
end
|
||||||
|
|
||||||
function _getobs(data::AbstractArray{T,N}, i) where {T,N}
|
function _getobs(data::AbstractArray{T,N}, i) where {T,N}
|
||||||
getindex(data, ntuple(i->Colon(), N-1)..., i)
|
getindex(data, ntuple(i->Colon(), Val(N-1))..., i)
|
||||||
end
|
end
|
||||||
|
|
||||||
_getobs(data::Tuple, i) = map(x -> _getobs(x, i), data)
|
_getobs(data::Union{Tuple,NamedTuple}, i) = map(x -> _getobs(x, i), data)
|
||||||
|
|
||||||
Base.eltype(d::DataLoader{D}) where D = D
|
Base.eltype(d::DataLoader{D}) where D = D
|
||||||
|
|
15
test/data.jl
15
test/data.jl
|
@ -38,6 +38,21 @@
|
||||||
@test batches[3][1] == X[:,5:5]
|
@test batches[3][1] == X[:,5:5]
|
||||||
@test batches[3][2] == Y[5:5]
|
@test batches[3][2] == Y[5:5]
|
||||||
|
|
||||||
|
# test with NamedTuple
|
||||||
|
d = DataLoader((x = X, y = Y), batchsize=2)
|
||||||
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == NamedTuple{(:x, :y), Tuple{typeof(X), typeof(Y)}}
|
||||||
|
@test length(batches) == 3
|
||||||
|
@test length(batches[1]) == 2
|
||||||
|
@test length(batches[2]) == 2
|
||||||
|
@test length(batches[3]) == 2
|
||||||
|
@test batches[1][1] == batches[1].x == X[:,1:2]
|
||||||
|
@test batches[1][2] == batches[1].y == Y[1:2]
|
||||||
|
@test batches[2][1] == batches[2].x == X[:,3:4]
|
||||||
|
@test batches[2][2] == batches[2].y == Y[3:4]
|
||||||
|
@test batches[3][1] == batches[3].x == X[:,5:5]
|
||||||
|
@test batches[3][2] == batches[3].y == Y[5:5]
|
||||||
|
|
||||||
# test interaction with `train!`
|
# test interaction with `train!`
|
||||||
θ = ones(2)
|
θ = ones(2)
|
||||||
X = zeros(2, 10)
|
X = zeros(2, 10)
|
||||||
|
|
Loading…
Reference in New Issue