diff --git a/src/data/mnist.jl b/src/data/mnist.jl index ba9960f9..c4cfc1b0 100644 --- a/src/data/mnist.jl +++ b/src/data/mnist.jl @@ -69,12 +69,33 @@ end getfeatures(io::IO, index::Integer) = vec(getimage(io, index)) +""" + images() + images(:test) + +Load the MNIST images. + +Each image is a 28×28 array of `Gray` colour values (see Colors.jl). + +Returns the 60,000 training images by default; pass `:test` to retreive the +10,000 test images. +""" function images(set = :train) io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES)) _, N, nrows, ncols = imageheader(io) [rawimage(io) for _ in 1:N] end +""" + labels() + labels(:test) + +Load the labels corresponding to each of the images returned from `images()`. +Each label is a number from 0-9. + +Returns the 60,000 training labels by default; pass `:test` to retreive the +10,000 test labels. +""" function labels(set = :train) io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS)) _, N = labelheader(io)