reuse utils from mnist.jl
This commit is contained in:
parent
95d72d7f79
commit
73a526b1de
@ -1,17 +1,9 @@
|
|||||||
module FashionMNIST
|
module FashionMNIST
|
||||||
|
|
||||||
using CodecZlib, Colors
|
using ..MNIST: gzopen, imageheader, rawimage, labelheader, rawlabel
|
||||||
|
|
||||||
const Gray = Colors.Gray{Colors.N0f8}
|
|
||||||
|
|
||||||
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
|
const dir = joinpath(@__DIR__, "../../deps/fashion-mnist")
|
||||||
|
|
||||||
function gzopen(f, file)
|
|
||||||
open(file) do io
|
|
||||||
f(GzipDecompressorStream(io))
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
function load()
|
function load()
|
||||||
mkpath(dir)
|
mkpath(dir)
|
||||||
cd(dir) do
|
cd(dir) do
|
||||||
@ -29,53 +21,11 @@ function load()
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
const IMAGEOFFSET = 16
|
|
||||||
const LABELOFFSET = 8
|
|
||||||
|
|
||||||
const NROWS = 28
|
|
||||||
const NCOLS = 28
|
|
||||||
|
|
||||||
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
|
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
|
||||||
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
|
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
|
||||||
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
|
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
|
||||||
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
|
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
|
||||||
|
|
||||||
function imageheader(io::IO)
|
|
||||||
magic_number = bswap(read(io, UInt32))
|
|
||||||
total_items = bswap(read(io, UInt32))
|
|
||||||
nrows = bswap(read(io, UInt32))
|
|
||||||
ncols = bswap(read(io, UInt32))
|
|
||||||
return magic_number, Int(total_items), Int(nrows), Int(ncols)
|
|
||||||
end
|
|
||||||
|
|
||||||
function labelheader(io::IO)
|
|
||||||
magic_number = bswap(read(io, UInt32))
|
|
||||||
total_items = bswap(read(io, UInt32))
|
|
||||||
return magic_number, Int(total_items)
|
|
||||||
end
|
|
||||||
|
|
||||||
function rawimage(io::IO)
|
|
||||||
img = Array{Gray}(undef, NCOLS, NROWS)
|
|
||||||
for i in 1:NCOLS, j in 1:NROWS
|
|
||||||
img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8))
|
|
||||||
end
|
|
||||||
return img
|
|
||||||
end
|
|
||||||
|
|
||||||
function rawimage(io::IO, index::Integer)
|
|
||||||
seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1))
|
|
||||||
return rawimage(io)
|
|
||||||
end
|
|
||||||
|
|
||||||
rawlabel(io::IO) = Int(read(io, UInt8))
|
|
||||||
|
|
||||||
function rawlabel(io::IO, index::Integer)
|
|
||||||
seek(io, LABELOFFSET + (index - 1))
|
|
||||||
return rawlabel(io)
|
|
||||||
end
|
|
||||||
|
|
||||||
getfeatures(io::IO, index::Integer) = vec(getimage(io, index))
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
images()
|
images()
|
||||||
images(:test)
|
images(:test)
|
||||||
@ -111,5 +61,4 @@ function labels(set = :train)
|
|||||||
[rawlabel(io) for _ = 1:N]
|
[rawlabel(io) for _ = 1:N]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
||||||
end
|
end
|
||||||
|
Loading…
Reference in New Issue
Block a user