diff --git a/REQUIRE b/REQUIRE index 8e718a92..eb7545da 100644 --- a/REQUIRE +++ b/REQUIRE @@ -6,3 +6,5 @@ NNlib ForwardDiff 0.5.0 Requires Adapt +GZip +Colors diff --git a/src/data/Data.jl b/src/data/Data.jl index ffea729c..2844d0ae 100644 --- a/src/data/Data.jl +++ b/src/data/Data.jl @@ -8,6 +8,7 @@ function __init__() mkpath(deps()) end +include("mnist.jl") include("cmudict.jl") using .CMUDict diff --git a/src/data/mnist.jl b/src/data/mnist.jl new file mode 100644 index 00000000..ba9960f9 --- /dev/null +++ b/src/data/mnist.jl @@ -0,0 +1,84 @@ +module MNIST + +using GZip, Colors + +const Gray = Colors.Gray{Colors.N0f8} + +const dir = joinpath(@__DIR__, "../../deps/mnist") + +function load() + mkpath(dir) + cd(dir) do + for file in ["train-images-idx3-ubyte", + "train-labels-idx1-ubyte", + "t10k-images-idx3-ubyte", + "t10k-labels-idx1-ubyte"] + isfile(file) && continue + download("http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz") + open(file, "w") do io + write(io, GZip.open(read, "$file.gz")) + end + end + end +end + +const IMAGEOFFSET = 16 +const LABELOFFSET = 8 + +const NROWS = 28 +const NCOLS = 28 + +const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte") +const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte") +const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte") +const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte") + +function imageheader(io::IO) + magic_number = bswap(read(io, UInt32)) + total_items = bswap(read(io, UInt32)) + nrows = bswap(read(io, UInt32)) + ncols = bswap(read(io, UInt32)) + return magic_number, Int(total_items), Int(nrows), Int(ncols) +end + +function labelheader(io::IO) + magic_number = bswap(read(io, UInt32)) + total_items = bswap(read(io, UInt32)) + return magic_number, Int(total_items) +end + +function rawimage(io::IO) + img = Array{Gray}(NCOLS, NROWS) + for i in 1:NCOLS, j in 1:NROWS + img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8)) + end + return img +end + +function rawimage(io::IO, index::Integer) + seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1)) + return rawimage(io) +end + +rawlabel(io::IO) = Int(read(io, UInt8)) + +function rawlabel(io::IO, index::Integer) + seek(io, LABELOFFSET + (index - 1)) + return rawlabel(io) +end + +getfeatures(io::IO, index::Integer) = vec(getimage(io, index)) + +function images(set = :train) + io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES)) + _, N, nrows, ncols = imageheader(io) + [rawimage(io) for _ in 1:N] +end + +function labels(set = :train) + io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS)) + _, N = labelheader(io) + [rawlabel(io) for _ = 1:N] +end + +end # module