mnist
This commit is contained in:
parent
140d50208f
commit
8cca7accf2
2
REQUIRE
2
REQUIRE
@ -6,3 +6,5 @@ NNlib
|
|||||||
ForwardDiff 0.5.0
|
ForwardDiff 0.5.0
|
||||||
Requires
|
Requires
|
||||||
Adapt
|
Adapt
|
||||||
|
GZip
|
||||||
|
Colors
|
||||||
|
@ -8,6 +8,7 @@ function __init__()
|
|||||||
mkpath(deps())
|
mkpath(deps())
|
||||||
end
|
end
|
||||||
|
|
||||||
|
include("mnist.jl")
|
||||||
include("cmudict.jl")
|
include("cmudict.jl")
|
||||||
using .CMUDict
|
using .CMUDict
|
||||||
|
|
||||||
|
84
src/data/mnist.jl
Normal file
84
src/data/mnist.jl
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
module MNIST
|
||||||
|
|
||||||
|
using GZip, Colors
|
||||||
|
|
||||||
|
const Gray = Colors.Gray{Colors.N0f8}
|
||||||
|
|
||||||
|
const dir = joinpath(@__DIR__, "../../deps/mnist")
|
||||||
|
|
||||||
|
function load()
|
||||||
|
mkpath(dir)
|
||||||
|
cd(dir) do
|
||||||
|
for file in ["train-images-idx3-ubyte",
|
||||||
|
"train-labels-idx1-ubyte",
|
||||||
|
"t10k-images-idx3-ubyte",
|
||||||
|
"t10k-labels-idx1-ubyte"]
|
||||||
|
isfile(file) && continue
|
||||||
|
download("http://yann.lecun.com/exdb/mnist/$file.gz", "$file.gz")
|
||||||
|
open(file, "w") do io
|
||||||
|
write(io, GZip.open(read, "$file.gz"))
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
const IMAGEOFFSET = 16
|
||||||
|
const LABELOFFSET = 8
|
||||||
|
|
||||||
|
const NROWS = 28
|
||||||
|
const NCOLS = 28
|
||||||
|
|
||||||
|
const TRAINIMAGES = joinpath(dir, "train-images-idx3-ubyte")
|
||||||
|
const TRAINLABELS = joinpath(dir, "train-labels-idx1-ubyte")
|
||||||
|
const TESTIMAGES = joinpath(dir, "t10k-images-idx3-ubyte")
|
||||||
|
const TESTLABELS = joinpath(dir, "t10k-labels-idx1-ubyte")
|
||||||
|
|
||||||
|
function imageheader(io::IO)
|
||||||
|
magic_number = bswap(read(io, UInt32))
|
||||||
|
total_items = bswap(read(io, UInt32))
|
||||||
|
nrows = bswap(read(io, UInt32))
|
||||||
|
ncols = bswap(read(io, UInt32))
|
||||||
|
return magic_number, Int(total_items), Int(nrows), Int(ncols)
|
||||||
|
end
|
||||||
|
|
||||||
|
function labelheader(io::IO)
|
||||||
|
magic_number = bswap(read(io, UInt32))
|
||||||
|
total_items = bswap(read(io, UInt32))
|
||||||
|
return magic_number, Int(total_items)
|
||||||
|
end
|
||||||
|
|
||||||
|
function rawimage(io::IO)
|
||||||
|
img = Array{Gray}(NCOLS, NROWS)
|
||||||
|
for i in 1:NCOLS, j in 1:NROWS
|
||||||
|
img[i, j] = reinterpret(Colors.N0f8, read(io, UInt8))
|
||||||
|
end
|
||||||
|
return img
|
||||||
|
end
|
||||||
|
|
||||||
|
function rawimage(io::IO, index::Integer)
|
||||||
|
seek(io, IMAGEOFFSET + NROWS * NCOLS * (index - 1))
|
||||||
|
return rawimage(io)
|
||||||
|
end
|
||||||
|
|
||||||
|
rawlabel(io::IO) = Int(read(io, UInt8))
|
||||||
|
|
||||||
|
function rawlabel(io::IO, index::Integer)
|
||||||
|
seek(io, LABELOFFSET + (index - 1))
|
||||||
|
return rawlabel(io)
|
||||||
|
end
|
||||||
|
|
||||||
|
getfeatures(io::IO, index::Integer) = vec(getimage(io, index))
|
||||||
|
|
||||||
|
function images(set = :train)
|
||||||
|
io = IOBuffer(read(set == :train ? TRAINIMAGES : TESTIMAGES))
|
||||||
|
_, N, nrows, ncols = imageheader(io)
|
||||||
|
[rawimage(io) for _ in 1:N]
|
||||||
|
end
|
||||||
|
|
||||||
|
function labels(set = :train)
|
||||||
|
io = IOBuffer(read(set == :train ? TRAINLABELS : TESTLABELS))
|
||||||
|
_, N = labelheader(io)
|
||||||
|
[rawlabel(io) for _ = 1:N]
|
||||||
|
end
|
||||||
|
|
||||||
|
end # module
|
Loading…
Reference in New Issue
Block a user