fix type instability
This commit is contained in:
parent
89191bdeb1
commit
14e7181c7c
|
@ -1,7 +1,7 @@
|
|||
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
|
||||
|
||||
struct DataLoader
|
||||
data
|
||||
struct DataLoader{D}
|
||||
data::D
|
||||
batchsize::Int
|
||||
nobs::Int
|
||||
partial::Bool
|
||||
|
@ -24,7 +24,7 @@ If `partial=false`, drops the last mini-batch if it is smaller than the batchsiz
|
|||
|
||||
The original data is preserved in the `data` field of the DataLoader.
|
||||
|
||||
Example usage:
|
||||
Usage example:
|
||||
|
||||
Xtrain = rand(10, 100)
|
||||
train_loader = DataLoader(Xtrain, batchsize=2)
|
||||
|
@ -98,5 +98,10 @@ function _nobs(data::Tuple)
|
|||
return n
|
||||
end
|
||||
|
||||
_getobs(data::AbstractArray, i) = data[(Base.Colon() for _=1:ndims(data)-1)..., i]
|
||||
function _getobs(data::A, i) where A<:AbstractArray{T,N} where {T,N}
|
||||
getindex(data, ntuple(i->Colon(), N-1)..., i)
|
||||
end
|
||||
|
||||
_getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
|
||||
|
||||
Base.eltype(d::DataLoader{D}) where D = D
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
d = DataLoader(X, batchsize=2)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 3
|
||||
@test batches[1] == X[:,1:2]
|
||||
@test batches[2] == X[:,3:4]
|
||||
|
@ -11,18 +12,21 @@
|
|||
|
||||
d = DataLoader(X, batchsize=2, partial=false)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == typeof(X)
|
||||
@test length(batches) == 2
|
||||
@test batches[1] == X[:,1:2]
|
||||
@test batches[2] == X[:,3:4]
|
||||
|
||||
d = DataLoader((X,), batchsize=2, partial=false)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
||||
@test length(batches) == 2
|
||||
@test batches[1] == (X[:,1:2],)
|
||||
@test batches[2] == (X[:,3:4],)
|
||||
|
||||
d = DataLoader((X, Y), batchsize=2)
|
||||
batches = collect(d)
|
||||
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
||||
@test length(batches) == 3
|
||||
@test length(batches[1]) == 2
|
||||
@test length(batches[2]) == 2
|
||||
|
|
|
@ -7,44 +7,40 @@ using IterTools: ncycle
|
|||
|
||||
Random.seed!(0)
|
||||
|
||||
@testset "Flux" begin
|
||||
@testset "Utils" begin
|
||||
include("utils.jl")
|
||||
end
|
||||
|
||||
@testset "Utils" begin
|
||||
include("utils.jl")
|
||||
@testset "Onehot" begin
|
||||
include("onehot.jl")
|
||||
end
|
||||
|
||||
@testset "Optimise" begin
|
||||
include("optimise.jl")
|
||||
end
|
||||
|
||||
@testset "Data" begin
|
||||
include("data.jl")
|
||||
end
|
||||
|
||||
@testset "Layers" begin
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/conv.jl")
|
||||
end
|
||||
|
||||
@testset "CUDA" begin
|
||||
if Flux.use_cuda[]
|
||||
include("cuda/cuda.jl")
|
||||
else
|
||||
@warn "CUDA unavailable, not testing GPU support"
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Onehot" begin
|
||||
include("onehot.jl")
|
||||
@testset "Docs" begin
|
||||
if VERSION >= v"1.4"
|
||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||
doctest(Flux)
|
||||
end
|
||||
|
||||
@testset "Optimise" begin
|
||||
include("optimise.jl")
|
||||
end
|
||||
|
||||
@testset "Data" begin
|
||||
include("data.jl")
|
||||
end
|
||||
|
||||
@testset "Layers" begin
|
||||
include("layers/basic.jl")
|
||||
include("layers/normalisation.jl")
|
||||
include("layers/stateless.jl")
|
||||
include("layers/conv.jl")
|
||||
end
|
||||
|
||||
@testset "CUDA" begin
|
||||
if Flux.use_cuda[]
|
||||
include("cuda/cuda.jl")
|
||||
else
|
||||
@warn "CUDA unavailable, not testing GPU support"
|
||||
end
|
||||
end
|
||||
|
||||
@testset "Docs" begin
|
||||
if VERSION >= v"1.4"
|
||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||
doctest(Flux)
|
||||
end
|
||||
end
|
||||
|
||||
end # testset Flux
|
||||
end
|
Loading…
Reference in New Issue