Module vipy.data.mnist
Expand source code Browse git
import gzip
import os
import struct
import subprocess
from array import array

import numpy as np

import vipy.image
from vipy.util import remkdir
# Download locations of the four gzip-compressed MNIST files, each paired with a
# 40-hex-digit constant named *_SHA1.
# NOTE(review): the *_SHA1 values are presumably the SHA-1 digests of the matching
# archives; nothing in the visible code checks them - confirm before relying on them.
_MNIST_URL_ROOT = 'http://yann.lecun.com/exdb/mnist/'

# Training split: images and labels
TRAIN_IMG_URL = _MNIST_URL_ROOT + 'train-images-idx3-ubyte.gz'
TRAIN_IMG_SHA1 = '6c95f4b05d2bf285e1bfb0e7960c31bd3b3f8a7d'
TRAIN_LBL_URL = _MNIST_URL_ROOT + 'train-labels-idx1-ubyte.gz'
TRAIN_LBL_SHA1 = '2a80914081dc54586dbdf242f9805a6b8d2a15fc'

# Test ("t10k") split: images and labels
TEST_IMG_URL = _MNIST_URL_ROOT + 't10k-images-idx3-ubyte.gz'
TEST_IMG_SHA1 = 'c3a25af1f52dad7f726cce8cacb138654b760d48'
TEST_LBL_URL = _MNIST_URL_ROOT + 't10k-labels-idx1-ubyte.gz'
TEST_LBL_SHA1 = '763e7fa3757d93b0cdec073cef058b2004252c17'
class MNIST():
    """Download the MNIST digit files into a directory and load them as vipy datasets.

    Construct with an output directory, then call trainset() or testset().
    """

    # Basenames (without the '.gz' suffix) of the four files that make up MNIST
    _BASENAMES = ('train-images-idx3-ubyte', 'train-labels-idx1-ubyte',
                  't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')

    def __init__(self, outdir):
        """Download the MNIST files to outdir unless they are already there.

        Args:
            outdir: directory that holds (or will receive) the MNIST files.  It is passed
                through vipy.util.remkdir (presumably creating it when missing - confirm)
                and the result is stored as self.outdir.
        """
        self.outdir = remkdir(outdir)
        if not self._downloaded():
            print('[vipy.data.mnist]: downloading MNIST to "%s"' % self.outdir)
            self._wget()

    def _downloaded(self):
        """Return True if all four files exist in self.outdir, either all as '.gz' or all unpacked."""
        gzip_downloaded = all(os.path.exists(os.path.join(self.outdir, f + '.gz')) for f in self._BASENAMES)
        unpacked_downloaded = all(os.path.exists(os.path.join(self.outdir, f)) for f in self._BASENAMES)
        return (unpacked_downloaded or gzip_downloaded)

    def _wget(self):
        """Fetch the four MNIST archives into self.outdir with the external 'wget' program.

        The command is passed as an argument list (no shell), so an outdir containing
        spaces or shell metacharacters is handled correctly.  A failed download is
        reported with a printed warning rather than an exception.
        """
        for url in (TRAIN_IMG_URL, TRAIN_LBL_URL, TEST_IMG_URL, TEST_LBL_URL):
            try:
                returncode = subprocess.call(['wget', '--directory-prefix=%s' % self.outdir, url])
            except OSError as e:
                # Raised when the wget executable itself cannot be started
                print('[vipy.data.mnist]: WARNING - unable to run wget for "%s" (%s)' % (url, str(e)))
                continue
            if returncode != 0:
                print('[vipy.data.mnist]: WARNING - wget exited with status %d for "%s"' % (returncode, url))

    def _path(self, basename):
        """Return the '.gz' path for basename in self.outdir, or the unpacked path when only that one exists."""
        gzpath = os.path.join(self.outdir, basename + '.gz')
        rawpath = os.path.join(self.outdir, basename)
        # _downloaded() accepts unpacked files, so the loaders must be able to find them too
        return rawpath if (not os.path.exists(gzpath) and os.path.exists(rawpath)) else gzpath

    @staticmethod
    def _open(path):
        """Open path for binary reading, decompressing transparently when it ends in '.gz'."""
        return gzip.open(path, 'rb') if str(path).endswith('.gz') else open(path, 'rb')

    @staticmethod
    def _labels(gzfile):
        """Read every label from an MNIST label file.

        Args:
            gzfile: path to the label file ('.gz' or unpacked).

        Returns:
            array('B') holding all bytes that follow the 8-byte header.

        Raises:
            ValueError: if the first big-endian header word is not 2049.
        """
        with MNIST._open(gzfile) as file:
            # Header: two big-endian uint32 words, the magic number and the item count
            magic, size = struct.unpack(">II", file.read(8))
            if magic != 2049:
                raise ValueError('Magic number mismatch, expected 2049, got %d' % magic)
            labels = array("B", file.read())
        return labels

    @staticmethod
    def _imread(dataset, index):
        """Read a single image from an MNIST image file.

        Adapted from: https://github.com/sorki/python-mnist/blob/master/mnist/loader.py

        Args:
            dataset: path to the image file ('.gz' or unpacked).
            index: zero-based position of the image within the file.

        Returns:
            uint8 numpy array of shape (rows, cols), with rows and cols taken from the header.

        Raises:
            ValueError: if the first big-endian header word is not 2051.
            IndexError: if index is outside [0, number of images in the header).
        """
        with MNIST._open(dataset) as file:
            # Header: four big-endian uint32 words - magic, image count, rows, cols
            magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, got %d' % magic)
            if not 0 <= index < size:
                raise IndexError('Image index %d out of range, file contains %d images' % (index, size))
            # Images are stored back to back after the 16-byte header, one byte per pixel
            file.seek(index * rows * cols + 16)
            pixels = file.read(rows * cols)
        # frombuffer yields a read-only view of the bytes; copy so the caller owns a writable array
        return np.frombuffer(pixels, dtype=np.uint8).reshape((rows, cols)).copy()

    @staticmethod
    def _dataset(img_gzfile, label_gzfile, N):
        """Read all labels and the first N images of one MNIST split.

        Args:
            img_gzfile: path to the image file ('.gz' or unpacked).
            label_gzfile: path to the label file ('.gz' or unpacked).
            N: number of images to read from the start of the image file.

        Returns:
            (y, x) where y is a list of every label in the label file and x is a uint8
            numpy array of shape (N, rows, cols).

        Raises:
            ValueError: on a magic number mismatch, or if the image file holds fewer than N images.
        """
        y = MNIST._labels(label_gzfile).tolist()
        N = max(N, 0)  # a negative count yields no images, never a read-to-end-of-file
        with MNIST._open(img_gzfile) as imgfile:
            magic, size, rows, cols = struct.unpack(">IIII", imgfile.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, got %d' % magic)
            # One bulk read and decode instead of N per-image list conversions
            pixels = imgfile.read(N * rows * cols)
        if len(pixels) != N * rows * cols:
            raise ValueError('Expected %d images of size %dx%d, but read only %d bytes' % (N, rows, cols, len(pixels)))
        x = np.frombuffer(pixels, dtype=np.uint8).reshape((N, rows, cols)).copy()
        return (y, x)

    def _load(self, imgname, labelname, N, name):
        """Decode N images of one split and wrap them in a vipy.dataset.Dataset identified by name."""
        # Imported here so the submodule is loaded even if nothing else has imported it
        import vipy.dataset
        (labels, images) = self._dataset(self._path(imgname), self._path(labelname), N)
        return vipy.dataset.Dataset([vipy.image.ImageCategory(array=img, category=str(y), colorspace='lum') for (y, img) in zip(labels, images)], name)

    def trainset(self):
        """Return the 60000-image training split as a vipy.dataset.Dataset identified by 'mnist'."""
        return self._load('train-images-idx3-ubyte', 'train-labels-idx1-ubyte', 60000, 'mnist')

    def testset(self):
        """Return the 10000-image test split as a vipy.dataset.Dataset identified by 'mnist_test'."""
        return self._load('t10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte', 10000, 'mnist_test')
Classes
class MNIST (outdir)
-
Download the MNIST archives from the URLs above to outdir if they are not already present, then load them with trainset() or testset().
Expand source code Browse git
class MNIST():
    """Download the MNIST digit files into a directory and load them as vipy datasets.

    Construct with an output directory, then call trainset() or testset().
    """

    # Basenames (without the '.gz' suffix) of the four files that make up MNIST
    _BASENAMES = ('train-images-idx3-ubyte', 'train-labels-idx1-ubyte',
                  't10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte')

    def __init__(self, outdir):
        """Download the MNIST files to outdir unless they are already there.

        Args:
            outdir: directory that holds (or will receive) the MNIST files.  It is passed
                through vipy.util.remkdir (presumably creating it when missing - confirm)
                and the result is stored as self.outdir.
        """
        self.outdir = remkdir(outdir)
        if not self._downloaded():
            print('[vipy.data.mnist]: downloading MNIST to "%s"' % self.outdir)
            self._wget()

    def _downloaded(self):
        """Return True if all four files exist in self.outdir, either all as '.gz' or all unpacked."""
        gzip_downloaded = all(os.path.exists(os.path.join(self.outdir, f + '.gz')) for f in self._BASENAMES)
        unpacked_downloaded = all(os.path.exists(os.path.join(self.outdir, f)) for f in self._BASENAMES)
        return (unpacked_downloaded or gzip_downloaded)

    def _wget(self):
        """Fetch the four MNIST archives into self.outdir with the external 'wget' program.

        The command is passed as an argument list (no shell), so an outdir containing
        spaces or shell metacharacters is handled correctly.  A failed download is
        reported with a printed warning rather than an exception.
        """
        for url in (TRAIN_IMG_URL, TRAIN_LBL_URL, TEST_IMG_URL, TEST_LBL_URL):
            try:
                returncode = subprocess.call(['wget', '--directory-prefix=%s' % self.outdir, url])
            except OSError as e:
                # Raised when the wget executable itself cannot be started
                print('[vipy.data.mnist]: WARNING - unable to run wget for "%s" (%s)' % (url, str(e)))
                continue
            if returncode != 0:
                print('[vipy.data.mnist]: WARNING - wget exited with status %d for "%s"' % (returncode, url))

    def _path(self, basename):
        """Return the '.gz' path for basename in self.outdir, or the unpacked path when only that one exists."""
        gzpath = os.path.join(self.outdir, basename + '.gz')
        rawpath = os.path.join(self.outdir, basename)
        # _downloaded() accepts unpacked files, so the loaders must be able to find them too
        return rawpath if (not os.path.exists(gzpath) and os.path.exists(rawpath)) else gzpath

    @staticmethod
    def _open(path):
        """Open path for binary reading, decompressing transparently when it ends in '.gz'."""
        return gzip.open(path, 'rb') if str(path).endswith('.gz') else open(path, 'rb')

    @staticmethod
    def _labels(gzfile):
        """Read every label from an MNIST label file.

        Args:
            gzfile: path to the label file ('.gz' or unpacked).

        Returns:
            array('B') holding all bytes that follow the 8-byte header.

        Raises:
            ValueError: if the first big-endian header word is not 2049.
        """
        with MNIST._open(gzfile) as file:
            # Header: two big-endian uint32 words, the magic number and the item count
            magic, size = struct.unpack(">II", file.read(8))
            if magic != 2049:
                raise ValueError('Magic number mismatch, expected 2049, got %d' % magic)
            labels = array("B", file.read())
        return labels

    @staticmethod
    def _imread(dataset, index):
        """Read a single image from an MNIST image file.

        Adapted from: https://github.com/sorki/python-mnist/blob/master/mnist/loader.py

        Args:
            dataset: path to the image file ('.gz' or unpacked).
            index: zero-based position of the image within the file.

        Returns:
            uint8 numpy array of shape (rows, cols), with rows and cols taken from the header.

        Raises:
            ValueError: if the first big-endian header word is not 2051.
            IndexError: if index is outside [0, number of images in the header).
        """
        with MNIST._open(dataset) as file:
            # Header: four big-endian uint32 words - magic, image count, rows, cols
            magic, size, rows, cols = struct.unpack(">IIII", file.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, got %d' % magic)
            if not 0 <= index < size:
                raise IndexError('Image index %d out of range, file contains %d images' % (index, size))
            # Images are stored back to back after the 16-byte header, one byte per pixel
            file.seek(index * rows * cols + 16)
            pixels = file.read(rows * cols)
        # frombuffer yields a read-only view of the bytes; copy so the caller owns a writable array
        return np.frombuffer(pixels, dtype=np.uint8).reshape((rows, cols)).copy()

    @staticmethod
    def _dataset(img_gzfile, label_gzfile, N):
        """Read all labels and the first N images of one MNIST split.

        Args:
            img_gzfile: path to the image file ('.gz' or unpacked).
            label_gzfile: path to the label file ('.gz' or unpacked).
            N: number of images to read from the start of the image file.

        Returns:
            (y, x) where y is a list of every label in the label file and x is a uint8
            numpy array of shape (N, rows, cols).

        Raises:
            ValueError: on a magic number mismatch, or if the image file holds fewer than N images.
        """
        y = MNIST._labels(label_gzfile).tolist()
        N = max(N, 0)  # a negative count yields no images, never a read-to-end-of-file
        with MNIST._open(img_gzfile) as imgfile:
            magic, size, rows, cols = struct.unpack(">IIII", imgfile.read(16))
            if magic != 2051:
                raise ValueError('Magic number mismatch, expected 2051, got %d' % magic)
            # One bulk read and decode instead of N per-image list conversions
            pixels = imgfile.read(N * rows * cols)
        if len(pixels) != N * rows * cols:
            raise ValueError('Expected %d images of size %dx%d, but read only %d bytes' % (N, rows, cols, len(pixels)))
        x = np.frombuffer(pixels, dtype=np.uint8).reshape((N, rows, cols)).copy()
        return (y, x)

    def _load(self, imgname, labelname, N, name):
        """Decode N images of one split and wrap them in a vipy.dataset.Dataset identified by name."""
        # Imported here so the submodule is loaded even if nothing else has imported it
        import vipy.dataset
        (labels, images) = self._dataset(self._path(imgname), self._path(labelname), N)
        return vipy.dataset.Dataset([vipy.image.ImageCategory(array=img, category=str(y), colorspace='lum') for (y, img) in zip(labels, images)], name)

    def trainset(self):
        """Return the 60000-image training split as a vipy.dataset.Dataset identified by 'mnist'."""
        return self._load('train-images-idx3-ubyte', 'train-labels-idx1-ubyte', 60000, 'mnist')

    def testset(self):
        """Return the 10000-image test split as a vipy.dataset.Dataset identified by 'mnist_test'."""
        return self._load('t10k-images-idx3-ubyte', 't10k-labels-idx1-ubyte', 10000, 'mnist_test')
Subclasses
Methods
def testset(self)
-
Expand source code Browse git
def testset(self):
    """Return the 10000-image MNIST test split as a vipy.dataset.Dataset identified by 'mnist_test'.

    Reads 't10k-images-idx3-ubyte.gz' and 't10k-labels-idx1-ubyte.gz' from self.outdir
    and wraps each image in a vipy.image.ImageCategory whose category is the label as a string.
    """
    # Imported here so the submodule is loaded even if nothing else has imported it
    import vipy.dataset
    labelfile = os.path.join(self.outdir, 't10k-labels-idx1-ubyte.gz')
    imgfile = os.path.join(self.outdir, 't10k-images-idx3-ubyte.gz')
    (labels, images) = self._dataset(imgfile, labelfile, 10000)
    return vipy.dataset.Dataset([vipy.image.ImageCategory(array=img, category=str(y), colorspace='lum') for (y, img) in zip(labels, images)], 'mnist_test')
def trainset(self)
-
Expand source code Browse git
def trainset(self):
    """Return the 60000-image MNIST training split as a vipy.dataset.Dataset identified by 'mnist'.

    Reads 'train-images-idx3-ubyte.gz' and 'train-labels-idx1-ubyte.gz' from self.outdir
    and wraps each image in a vipy.image.ImageCategory whose category is the label as a string.
    """
    # Imported here so the submodule is loaded even if nothing else has imported it
    import vipy.dataset
    labelfile = os.path.join(self.outdir, 'train-labels-idx1-ubyte.gz')
    imgfile = os.path.join(self.outdir, 'train-images-idx3-ubyte.gz')
    (labels, images) = self._dataset(imgfile, labelfile, 60000)
    return vipy.dataset.Dataset([vipy.image.ImageCategory(array=img, category=str(y), colorspace='lum') for (y, img) in zip(labels, images)], 'mnist')