From 73a526b1de465b0ad893d46fce09c0536d5a0d8b Mon Sep 17 00:00:00 2001 From: Christopher Murphy <6396338+c-p-murphy@users.noreply.github.com> Date: Wed, 3 Oct 2018 12:40:24 -0400 Subject: [PATCH] reuse utils from mnist.jl --- src/data/fashion-mnist.jl | 53 +-------------------------------------- 1 file changed, 1 insertion(+), 52 deletions(-) diff --git a/src/data/fashion-mnist.jl b/src/data/fashion-mnist.jl index d608d8bb..e4510b47 100644 --- a/src/data/fashion-mnist.jl +++ b/src/data/fashion-mnist.jl @@ -1,17 +1,9 @@ module FashionMNIST -using CodecZlib, Colors - -const Gray = Colors.Gray{Colors.N0f8} +using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel const dir = joinpath(@__DIR__, "../../deps/fashion-mnist") -function gzopen(f, file) - open(file) do io - f(GzipDecompressorStream(io)) - end -end - function load() mkpath(dir) cd(dir) do @@ -29,53 +21,11 @@ function load() end end -const IMAGEOFFSET = 16 -const LABELOFFSET = 8 - -const NROWS = 28 -const NCOLS = 28 - const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte") const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte") const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte") const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte") -function imageheader(io::IO) - magic_number = bswap(read(io, UInt32)) - total_items = bswap(read(io, UInt32)) - nrows = bswap(read(io, UInt32)) - ncols = bswap(read(io, UInt32)) - return magic_number, Int(total_items), Int(nrows), Int(ncols) -end - -function labelheader(io::IO) - magic_number = bswap(read(io, UInt32)) - total_items = bswap(read(io, UInt32)) - return magic_number, Int(total_items) -end - -function rawimage(io::IO) - img = Array{Gray}(undef, NCOLS, NROWS) - for i in 1:NCOLS, j in 1:NROWS - img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8)) - end - return img -end - -function rawimage(io::IO, index::Integer) - seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1)) - return rawimage(io) -end - -rawlabel(io::IO) = Int(read(io, UInt8)) - -function rawlabel(io::IO, index::Integer) - seek(io, LABELOFFSET + (index - 1)) - return rawlabel(io) -end - -getfeatures(io::IO, index::Integer) = vec(getimage(io, index)) - """ images() images(:test) @@ -111,5 +61,4 @@ function labels(set = :train) [rawlabel(io) for _ = 1:N] end - end