fix type instability
This commit is contained in:
parent
89191bdeb1
commit
14e7181c7c
@ -1,7 +1,7 @@
|
|||||||
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
|
# Adapted from Knet's src/data.jl (author: Deniz Yuret)
|
||||||
|
|
||||||
struct DataLoader
|
struct DataLoader{D}
|
||||||
data
|
data::D
|
||||||
batchsize::Int
|
batchsize::Int
|
||||||
nobs::Int
|
nobs::Int
|
||||||
partial::Bool
|
partial::Bool
|
||||||
@ -24,7 +24,7 @@ If `partial=false`, drops the last mini-batch if it is smaller than the batchsiz
|
|||||||
|
|
||||||
The original data is preserved in the `data` field of the DataLoader.
|
The original data is preserved in the `data` field of the DataLoader.
|
||||||
|
|
||||||
Example usage:
|
Usage example:
|
||||||
|
|
||||||
Xtrain = rand(10, 100)
|
Xtrain = rand(10, 100)
|
||||||
train_loader = DataLoader(Xtrain, batchsize=2)
|
train_loader = DataLoader(Xtrain, batchsize=2)
|
||||||
@ -98,5 +98,10 @@ function _nobs(data::Tuple)
|
|||||||
return n
|
return n
|
||||||
end
|
end
|
||||||
|
|
||||||
_getobs(data::AbstractArray, i) = data[(Base.Colon() for _=1:ndims(data)-1)..., i]
|
function _getobs(data::A, i) where A<:AbstractArray{T,N} where {T,N}
|
||||||
|
getindex(data, ntuple(i->Colon(), N-1)..., i)
|
||||||
|
end
|
||||||
|
|
||||||
_getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
|
_getobs(data::Tuple, i) = ((_getobs(x, i) for x in data)...,)
|
||||||
|
|
||||||
|
Base.eltype(d::DataLoader{D}) where D = D
|
@ -4,6 +4,7 @@
|
|||||||
|
|
||||||
d = DataLoader(X, batchsize=2)
|
d = DataLoader(X, batchsize=2)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == typeof(X)
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
@test batches[1] == X[:,1:2]
|
@test batches[1] == X[:,1:2]
|
||||||
@test batches[2] == X[:,3:4]
|
@test batches[2] == X[:,3:4]
|
||||||
@ -11,18 +12,21 @@
|
|||||||
|
|
||||||
d = DataLoader(X, batchsize=2, partial=false)
|
d = DataLoader(X, batchsize=2, partial=false)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == typeof(X)
|
||||||
@test length(batches) == 2
|
@test length(batches) == 2
|
||||||
@test batches[1] == X[:,1:2]
|
@test batches[1] == X[:,1:2]
|
||||||
@test batches[2] == X[:,3:4]
|
@test batches[2] == X[:,3:4]
|
||||||
|
|
||||||
d = DataLoader((X,), batchsize=2, partial=false)
|
d = DataLoader((X,), batchsize=2, partial=false)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == Tuple{typeof(X)}
|
||||||
@test length(batches) == 2
|
@test length(batches) == 2
|
||||||
@test batches[1] == (X[:,1:2],)
|
@test batches[1] == (X[:,1:2],)
|
||||||
@test batches[2] == (X[:,3:4],)
|
@test batches[2] == (X[:,3:4],)
|
||||||
|
|
||||||
d = DataLoader((X, Y), batchsize=2)
|
d = DataLoader((X, Y), batchsize=2)
|
||||||
batches = collect(d)
|
batches = collect(d)
|
||||||
|
@test eltype(batches) == eltype(d) == Tuple{typeof(X), typeof(Y)}
|
||||||
@test length(batches) == 3
|
@test length(batches) == 3
|
||||||
@test length(batches[1]) == 2
|
@test length(batches[1]) == 2
|
||||||
@test length(batches[2]) == 2
|
@test length(batches[2]) == 2
|
||||||
|
@ -7,44 +7,40 @@ using IterTools: ncycle
|
|||||||
|
|
||||||
Random.seed!(0)
|
Random.seed!(0)
|
||||||
|
|
||||||
@testset "Flux" begin
|
@testset "Utils" begin
|
||||||
|
include("utils.jl")
|
||||||
|
end
|
||||||
|
|
||||||
@testset "Utils" begin
|
@testset "Onehot" begin
|
||||||
include("utils.jl")
|
include("onehot.jl")
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "Optimise" begin
|
||||||
|
include("optimise.jl")
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "Data" begin
|
||||||
|
include("data.jl")
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "Layers" begin
|
||||||
|
include("layers/basic.jl")
|
||||||
|
include("layers/normalisation.jl")
|
||||||
|
include("layers/stateless.jl")
|
||||||
|
include("layers/conv.jl")
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "CUDA" begin
|
||||||
|
if Flux.use_cuda[]
|
||||||
|
include("cuda/cuda.jl")
|
||||||
|
else
|
||||||
|
@warn "CUDA unavailable, not testing GPU support"
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
@testset "Onehot" begin
|
@testset "Docs" begin
|
||||||
include("onehot.jl")
|
if VERSION >= v"1.4"
|
||||||
|
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
||||||
|
doctest(Flux)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
@testset "Optimise" begin
|
|
||||||
include("optimise.jl")
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "Data" begin
|
|
||||||
include("data.jl")
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "Layers" begin
|
|
||||||
include("layers/basic.jl")
|
|
||||||
include("layers/normalisation.jl")
|
|
||||||
include("layers/stateless.jl")
|
|
||||||
include("layers/conv.jl")
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "CUDA" begin
|
|
||||||
if Flux.use_cuda[]
|
|
||||||
include("cuda/cuda.jl")
|
|
||||||
else
|
|
||||||
@warn "CUDA unavailable, not testing GPU support"
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
@testset "Docs" begin
|
|
||||||
if VERSION >= v"1.4"
|
|
||||||
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
|
|
||||||
doctest(Flux)
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
end # testset Flux
|
|
Loading…
Reference in New Issue
Block a user