Module vipy.data.cifar

Expand source code Browse git
import os
import numpy as np
import vipy
import pickle


# Download locations and expected md5 digests of the python-version CIFAR archives.
# Both archives are served from the same URL root.
_CIFAR_URL_ROOT = 'https://www.cs.toronto.edu/~kriz'

CIFAR10_URL = '%s/cifar-10-python.tar.gz' % _CIFAR_URL_ROOT
CIFAR10_MD5 = 'c58f30108f718f92721af3b95e74349a'

CIFAR100_URL = '%s/cifar-100-python.tar.gz' % _CIFAR_URL_ROOT
CIFAR100_MD5 = 'eb9058c3a382ffc7106e4002c42a8d85'


class CIFAR10():
    """vipy.data.cifar.CIFAR10 class

    Downloads the python version of CIFAR-10 into `outdir` if it is not already
    unpacked there, then loads all train and test batches into memory.

    >>> D = vipy.data.cifar.CIFAR10('/path/to/outdir')
    >>> d = D.trainset()
    >>> im = d[0].mindim(512).show()

    Args:
        outdir: directory holding (or receiving) the unpacked archive; passed to vipy.util.remkdir
        name: string used in __repr__ and to form the '<name>_train' / '<name>_test' string passed to vipy.dataset.Dataset
        url: archive URL handed to vipy.downloader.download_and_unpack
        md5: md5 string forwarded to vipy.downloader.download_and_unpack

    """

    def __init__(self, outdir, name='cifar10', url=CIFAR10_URL, md5=CIFAR10_MD5):
        # Assigned first, matching CIFAR100.__init__
        self._name = name
        self._datadir = vipy.util.remkdir(outdir)
        self._subdir = 'cifar-10-batches-py'

        # All paths are built from self._datadir: that is where the archive is unpacked,
        # so the existence check and every archive path must look in the same place.
        batchdir = os.path.join(self._datadir, self._subdir)
        if not os.path.exists(os.path.join(batchdir, 'data_batch_1')):
            print('[vipy.data.cifar10]: downloading CIFAR-10 to "%s"' % self._datadir)
            vipy.downloader.download_and_unpack(url, self._datadir, md5=md5)

        self._train_archives = [os.path.join(batchdir, 'data_batch_%d' % k) for k in range(1, 6)]
        self._test_archives = [os.path.join(batchdir, 'test_batch')]

        f = os.path.join(batchdir, 'batches.meta')
        assert os.path.exists(f), 'CIFAR-10 metadata file not found: "%s"' % f
        with open(f, 'rb') as fo:
            # NOTE(security): pickle.load can execute arbitrary code; only point this class at trusted archives
            d = pickle.load(fo, encoding='bytes')
        self._classes = [x.decode("utf-8") for x in d[b'label_names']]

        self._trainset()
        self._testset()

    def __repr__(self):
        return '<vipy.data.%s: %s>' % (self._name, self._datadir)

    def classes(self):
        """Return the list of class name strings read from the metadata file"""
        return self._classes

    def trainset(self):
        """Return the training images as a vipy.dataset.Dataset of vipy.image.ImageCategory"""
        return self._as_dataset(self._trainset, self._trainlabels, 'train')

    def testset(self):
        """Return the test images as a vipy.dataset.Dataset of vipy.image.ImageCategory"""
        return self._as_dataset(self._testset, self._testlabels, 'test')

    def _trainset(self, labelkey=b'labels'):
        """Load the training batches into self._trainset and self._trainlabels, and return self.

        NOTE: the assignment below rebinds `self._trainset` on the instance from this
        method to the loaded image list, so this loader can only be called once per object.
        """
        (self._trainset, self._trainlabels) = self._load_batches(self._train_archives, labelkey)
        return self

    def _testset(self, labelkey=b'labels'):
        """Load the test batches into self._testset and self._testlabels, and return self.

        NOTE: the assignment below rebinds `self._testset` on the instance from this
        method to the loaded image list, so this loader can only be called once per object.
        """
        (self._testset, self._testlabels) = self._load_batches(self._test_archives, labelkey)
        return self

    def _as_dataset(self, images, labels, split):
        """Pair each image array with its class name and wrap the result as '<name>_<split>'"""
        return vipy.dataset.Dataset([vipy.image.ImageCategory(category=self._classes[y], array=x, colorspace='rgb') for (x, y) in zip(images, labels)],
                                    '%s_%s' % (self._name, split))

    @staticmethod
    def _load_batches(archives, labelkey):
        """Read pickled batch files and return (images, labels).

        Args:
            archives: list of paths to pickled batch files, each a dict with keys b'data' and `labelkey`
            labelkey: bytes key of the label list inside each batch dict

        Returns:
            (images, labels): images is a list of 32x32x3 arrays, labels is the flat list of
            label values concatenated over all archives in order
        """
        (data, labels) = ([], [])
        for f in archives:
            assert os.path.exists(f), 'CIFAR batch file not found: "%s"' % f
            with open(f, 'rb') as fo:
                # NOTE(security): pickle.load can execute arbitrary code; only point this class at trusted archives
                d = pickle.load(fo, encoding='bytes')
            data.append(d[b'data'])
            labels.extend(d[labelkey])

        # Each row is reshaped to (3, 32, 32), then the leading axis of size 3 is moved last
        images = [np.transpose(x.reshape(3, 32, 32), axes=(1, 2, 0)) for x in np.vstack(data)]
        return (images, labels)
            
            
class CIFAR100(CIFAR10):
    """vipy.data.cifar.CIFAR100 class

    Downloads the python version of CIFAR-100 into `datadir` if it is not already
    unpacked there, then loads the train and test sets using the fine labels.

    >>> D = vipy.data.cifar.CIFAR100('/path/to/outdir')
    >>> d = D.trainset()
    >>> im = d[0].mindim(512).show()

    Args:
        datadir: directory holding (or receiving) the unpacked archive; passed to vipy.util.remkdir
        name: string used in __repr__ and in the names of the returned datasets
        url: archive URL handed to vipy.downloader.download_and_unpack
        md5: md5 string forwarded to vipy.downloader.download_and_unpack

    """

    def __init__(self, datadir, name='cifar100', url=CIFAR100_URL, md5=CIFAR100_MD5):
        # CIFAR10.__init__ is not called here: the subdirectory, archive file names and
        # metadata keys all differ, so the full state is set up directly.
        self._name = name
        self._datadir = vipy.util.remkdir(datadir)
        self._subdir = 'cifar-100-python'

        # All paths are built from self._datadir: that is where the archive is unpacked,
        # so the existence check and every archive path must look in the same place.
        archivedir = os.path.join(self._datadir, self._subdir)
        if not os.path.exists(os.path.join(archivedir, 'train')):
            print('[vipy.data.cifar100]: downloading CIFAR-100 to "%s"' % self._datadir)
            vipy.downloader.download_and_unpack(url, self._datadir, md5=md5)

        self._train_archives = [os.path.join(archivedir, 'train')]
        self._test_archives = [os.path.join(archivedir, 'test')]

        f = os.path.join(archivedir, 'meta')
        assert os.path.exists(f), 'CIFAR-100 metadata file not found: "%s"' % f
        with open(f, 'rb') as fo:
            # NOTE(security): pickle.load can execute arbitrary code; only point this class at trusted archives
            d = pickle.load(fo, encoding='bytes')
        self._classes = [x.decode("utf-8") for x in d[b'fine_label_names']]
        self._coarse_classes = [x.decode("utf-8") for x in d[b'coarse_label_names']]

        self._trainset()
        self._testset()

    def _trainset(self):
        """Load the training archive via the CIFAR10 loader, reading labels from b'fine_labels'"""
        return super()._trainset(labelkey=b'fine_labels')

    def _testset(self):
        """Load the test archive via the CIFAR10 loader, reading labels from b'fine_labels'"""
        return super()._testset(labelkey=b'fine_labels')
    
        
    

Classes

class CIFAR10 (outdir, name='cifar10', url='https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz', md5='c58f30108f718f92721af3b95e74349a')

vipy.data.cifar.CIFAR10 class

>>> D = vipy.data.cifar.CIFAR10('/path/to/outdir')
>>> d = D.trainset()
>>> im = d[0].mindim(512).show()
Expand source code Browse git
class CIFAR10():
    """vipy.data.cifar.CIFAR10 class

    Downloads the python version of CIFAR-10 into `outdir` if it is not already
    unpacked there, then loads all train and test batches into memory.

    >>> D = vipy.data.cifar.CIFAR10('/path/to/outdir')
    >>> d = D.trainset()
    >>> im = d[0].mindim(512).show()

    Args:
        outdir: directory holding (or receiving) the unpacked archive; passed to vipy.util.remkdir
        name: string used in __repr__ and to form the '<name>_train' / '<name>_test' string passed to vipy.dataset.Dataset
        url: archive URL handed to vipy.downloader.download_and_unpack
        md5: md5 string forwarded to vipy.downloader.download_and_unpack

    """

    def __init__(self, outdir, name='cifar10', url=CIFAR10_URL, md5=CIFAR10_MD5):
        # Assigned first, matching CIFAR100.__init__
        self._name = name
        self._datadir = vipy.util.remkdir(outdir)
        self._subdir = 'cifar-10-batches-py'

        # All paths are built from self._datadir: that is where the archive is unpacked,
        # so the existence check and every archive path must look in the same place.
        batchdir = os.path.join(self._datadir, self._subdir)
        if not os.path.exists(os.path.join(batchdir, 'data_batch_1')):
            print('[vipy.data.cifar10]: downloading CIFAR-10 to "%s"' % self._datadir)
            vipy.downloader.download_and_unpack(url, self._datadir, md5=md5)

        self._train_archives = [os.path.join(batchdir, 'data_batch_%d' % k) for k in range(1, 6)]
        self._test_archives = [os.path.join(batchdir, 'test_batch')]

        f = os.path.join(batchdir, 'batches.meta')
        assert os.path.exists(f), 'CIFAR-10 metadata file not found: "%s"' % f
        with open(f, 'rb') as fo:
            # NOTE(security): pickle.load can execute arbitrary code; only point this class at trusted archives
            d = pickle.load(fo, encoding='bytes')
        self._classes = [x.decode("utf-8") for x in d[b'label_names']]

        self._trainset()
        self._testset()

    def __repr__(self):
        return '<vipy.data.%s: %s>' % (self._name, self._datadir)

    def classes(self):
        """Return the list of class name strings read from the metadata file"""
        return self._classes

    def trainset(self):
        """Return the training images as a vipy.dataset.Dataset of vipy.image.ImageCategory"""
        return self._as_dataset(self._trainset, self._trainlabels, 'train')

    def testset(self):
        """Return the test images as a vipy.dataset.Dataset of vipy.image.ImageCategory"""
        return self._as_dataset(self._testset, self._testlabels, 'test')

    def _trainset(self, labelkey=b'labels'):
        """Load the training batches into self._trainset and self._trainlabels, and return self.

        NOTE: the assignment below rebinds `self._trainset` on the instance from this
        method to the loaded image list, so this loader can only be called once per object.
        """
        (self._trainset, self._trainlabels) = self._load_batches(self._train_archives, labelkey)
        return self

    def _testset(self, labelkey=b'labels'):
        """Load the test batches into self._testset and self._testlabels, and return self.

        NOTE: the assignment below rebinds `self._testset` on the instance from this
        method to the loaded image list, so this loader can only be called once per object.
        """
        (self._testset, self._testlabels) = self._load_batches(self._test_archives, labelkey)
        return self

    def _as_dataset(self, images, labels, split):
        """Pair each image array with its class name and wrap the result as '<name>_<split>'"""
        return vipy.dataset.Dataset([vipy.image.ImageCategory(category=self._classes[y], array=x, colorspace='rgb') for (x, y) in zip(images, labels)],
                                    '%s_%s' % (self._name, split))

    @staticmethod
    def _load_batches(archives, labelkey):
        """Read pickled batch files and return (images, labels).

        Args:
            archives: list of paths to pickled batch files, each a dict with keys b'data' and `labelkey`
            labelkey: bytes key of the label list inside each batch dict

        Returns:
            (images, labels): images is a list of 32x32x3 arrays, labels is the flat list of
            label values concatenated over all archives in order
        """
        (data, labels) = ([], [])
        for f in archives:
            assert os.path.exists(f), 'CIFAR batch file not found: "%s"' % f
            with open(f, 'rb') as fo:
                # NOTE(security): pickle.load can execute arbitrary code; only point this class at trusted archives
                d = pickle.load(fo, encoding='bytes')
            data.append(d[b'data'])
            labels.extend(d[labelkey])

        # Each row is reshaped to (3, 32, 32), then the leading axis of size 3 is moved last
        images = [np.transpose(x.reshape(3, 32, 32), axes=(1, 2, 0)) for x in np.vstack(data)]
        return (images, labels)

Subclasses

Methods

def classes(self)
Expand source code Browse git
def classes(self):
    """Return the class names stored on this dataset object"""
    class_names = self._classes
    return class_names
def testset(self)
Expand source code Browse git
def testset(self):
    """Wrap the loaded test images and labels as a vipy.dataset.Dataset named '<name>_test'"""
    # One ImageCategory per (image, label) pair; the label indexes into the class name list
    images = [
        vipy.image.ImageCategory(category=self._classes[label], array=pixels, colorspace='rgb')
        for (pixels, label) in zip(self._testset, self._testlabels)
    ]
    return vipy.dataset.Dataset(images, '%s_test' % self._name)
def trainset(self)
Expand source code Browse git
def trainset(self):
    """Wrap the loaded training images and labels as a vipy.dataset.Dataset named '<name>_train'"""
    # One ImageCategory per (image, label) pair; the label indexes into the class name list
    images = [
        vipy.image.ImageCategory(category=self._classes[label], array=pixels, colorspace='rgb')
        for (pixels, label) in zip(self._trainset, self._trainlabels)
    ]
    return vipy.dataset.Dataset(images, '%s_train' % self._name)
class CIFAR100 (datadir, name='cifar100', url='https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz', md5='eb9058c3a382ffc7106e4002c42a8d85')

vipy.data.cifar.CIFAR100 class

>>> D = vipy.data.cifar.CIFAR100('/path/to/outdir')
>>> d = D.trainset()
>>> im = d[0].mindim(512).show()
Expand source code Browse git
class CIFAR100(CIFAR10):
    """vipy.data.cifar.CIFAR100 class

    Downloads the python version of CIFAR-100 into `datadir` if it is not already
    unpacked there, then loads the train and test sets using the fine labels.

    >>> D = vipy.data.cifar.CIFAR100('/path/to/outdir')
    >>> d = D.trainset()
    >>> im = d[0].mindim(512).show()

    Args:
        datadir: directory holding (or receiving) the unpacked archive; passed to vipy.util.remkdir
        name: string used in __repr__ and in the names of the returned datasets
        url: archive URL handed to vipy.downloader.download_and_unpack
        md5: md5 string forwarded to vipy.downloader.download_and_unpack

    """

    def __init__(self, datadir, name='cifar100', url=CIFAR100_URL, md5=CIFAR100_MD5):
        # CIFAR10.__init__ is not called here: the subdirectory, archive file names and
        # metadata keys all differ, so the full state is set up directly.
        self._name = name
        self._datadir = vipy.util.remkdir(datadir)
        self._subdir = 'cifar-100-python'

        # All paths are built from self._datadir: that is where the archive is unpacked,
        # so the existence check and every archive path must look in the same place.
        archivedir = os.path.join(self._datadir, self._subdir)
        if not os.path.exists(os.path.join(archivedir, 'train')):
            print('[vipy.data.cifar100]: downloading CIFAR-100 to "%s"' % self._datadir)
            vipy.downloader.download_and_unpack(url, self._datadir, md5=md5)

        self._train_archives = [os.path.join(archivedir, 'train')]
        self._test_archives = [os.path.join(archivedir, 'test')]

        f = os.path.join(archivedir, 'meta')
        assert os.path.exists(f), 'CIFAR-100 metadata file not found: "%s"' % f
        with open(f, 'rb') as fo:
            # NOTE(security): pickle.load can execute arbitrary code; only point this class at trusted archives
            d = pickle.load(fo, encoding='bytes')
        self._classes = [x.decode("utf-8") for x in d[b'fine_label_names']]
        self._coarse_classes = [x.decode("utf-8") for x in d[b'coarse_label_names']]

        self._trainset()
        self._testset()

    def _trainset(self):
        """Load the training archive via the CIFAR10 loader, reading labels from b'fine_labels'"""
        return super()._trainset(labelkey=b'fine_labels')

    def _testset(self):
        """Load the test archive via the CIFAR10 loader, reading labels from b'fine_labels'"""
        return super()._testset(labelkey=b'fine_labels')

Ancestors