Theano入门——CIFAR-10和CIFAR-100数据集

1.CIFAR-10数据集介绍

CIFAR-10数据集包含60000个32*32的彩色图像,共有10类。有50000个训练图像和10000个测试图像。
数据集分为5个训练块和1个测试块,每个块有10000个图像。测试块包含从每类随机选择的1000个图像。训练块以随机的顺序包含这些图像,但一些训练块可能比其它类包含更多的图像。训练块每类包含5000个图像。

类间完全互斥。汽车和卡车类没有重叠。“Automobile”只包含sedans,SUVs等等。“Truck”只包含大卡车。两者都不包含皮卡车。

2.CIFAR-10数据集Python版本

存档包含文件data_batch_1,data_batch_2,...,data_batch_5和test_batch。每个文件都是1个Python"pickled"对象。按下面方式加载并返回1个字典:

  1.  
    def unpickle(file):
  2.  
    import cPickle
  3.  
    fo = open(file, 'rb')
  4.  
    dict = cPickle.load(fo)
  5.  
    fo.close()
  6.  
    return dict

每个块文件包含1个带有如下元素的字典:
data——1个10000*3072大小的uint8s数组。数组的每行存储1张32*32的图像。第1个1024包含红色通道值,下1个包含绿色,最后的1024包含蓝色。图像存储以行顺序为主,所以数组的前32列为图像第1行的红色通道值。
labels——1个10000数的范围为0~9的列表。索引i的数值表示数组data中第i个图像的标签。

数据集中包含另外1个叫batches.meta的文件。它也包含1个Python字典对象。有如下列元素:

label_names——1个10元素的列表,给labels中的数值标签以有意义的名称。例如,label_names[0] == "airplane", label_names[1] == "automobile"等。

3.CIFAR-100数据集

数据集包含100小类,每小类包含600个图像,其中有500个训练图像和100个测试图像。100类被分组为20个大类。每个图像带有1个小类的“fine”标签和1个大类“coarse”标签。

4.CIFAR-100数据集Python版本

同CIFAR-10数据集Python版本。

5.CIFAR-10代码实现

(1)CIFAR-10数据集存放在相对文件路径data_dir_cifar10下。
(2)_load_batch_cifar10函数
该函数加载CIFAR-10格式的块文件。根据块文件名filename和相对文件路径data_dir_cifar10拼接得到块文件位置。用numpy中的load函数加载(用cPickle中的load函数也可以加载)返回batch,batch是1个字典,里面包含数据和标签。根据数据的索引'data'得到图像数据,根据标签的索引'labels'得到图像分类的标签,标签转换为one-hot编码形式,见前一篇文章对MNIST数据集的说明。最后把数据和标签中的元素的数据类型统一为dtype类型。
(3)concatenate函数
该函数当axis=0时将矩阵按行顺序从上往下摆放(列长度相等),当axis=1时将矩阵按列顺序从左往右摆放(行长度相等)。

(4)_grayscale函数
该函数首先将a变形为4维张量,维数为(a.shape[0],3,32,32)。之前a为矩阵形式,a的每行代表1个图片样本,a的列为图片中的所有像素按照红,绿,蓝的顺序排序的结果,即依次为图片所有像素红色通道值,图片所有像素绿色通道值,图片所有像素蓝色通道值。3表示颜色通道个数,32表示图片的行数和列数。可以这样理解,reshape函数根据先把二维的矩阵a的所有行排成1行,先把这1行切出a.shape[0]个行b,然后对行b切3份,每份都为行c,对行c切32份,每份为行d,再对行d切32份,每个元素为像素的单通道值。
mean(1)对应的是第2个轴(颜色通道轴),对第2个轴求平均值,即将三通道值求平均,最后第2个轴变为单通道(灰度通道)。最后变形得到2维矩阵(a.shape[0],32*32)。
(5)cifar10函数
cifar函数先调用_load_batch_cifar10函数读取块文件,返回的x和t都是列表形式。所以可以用append连接,连接后每个块文件里面的东西是用中括号括起来的,所以用concatenate函数把外面区分块文件类型的括号去掉。此时得到的x_train的行为图像样本,列为像素的红,绿,蓝通道值;t_train的行为图像标签,列为标签的one-hot编码值。x_test和t_test的结构同理。转换为灰度图像后x_train和x_test为矩阵,行为图像样本,列为像素归一化后的灰度值;t_train和t_test为矩阵,行为图像标签,列为每位one-hot编码值。

  1.  
    import numpy as np
  2.  
    import os
  3.  
    import cPickle as pickle
  4.  
    import glob
  5.  
    import matplotlib.pyplot as plt
  6.  
     
  7.  
    data_dir = "data"
  8.  
    data_dir_cifar10 = os.path.join(data_dir, "cifar-10-batches-py")
  9.  
    data_dir_cifar100 = os.path.join(data_dir, "cifar-100-python")
  10.  
     
  11.  
    class_names_cifar10 = np.load(os.path.join(data_dir_cifar10, "batches.meta"))
  12.  
    class_names_cifar100 = np.load(os.path.join(data_dir_cifar100, "meta"))
  13.  
     
  14.  
     
  15.  
    def one_hot(x, n):
  16.  
    """
  17.  
    convert index representation to one-hot representation
  18.  
    """
  19.  
    x = np.array(x)
  20.  
    assert x.ndim == 1
  21.  
    return np.eye(n)[x]
  22.  
     
  23.  
    def _load_batch_cifar10(filename, dtype='float64'):
  24.  
    """
  25.  
    load a batch in the CIFAR-10 format
  26.  
    """
  27.  
    path = os.path.join(data_dir_cifar10, filename)
  28.  
    batch = np.load(path)
  29.  
    data = batch['data'] / 255.0 # scale between [0, 1]
  30.  
    labels = one_hot(batch['labels'], n=10) # convert labels to one-hot representation
  31.  
    return data.astype(dtype), labels.astype(dtype)
  32.  
     
  33.  
     
  34.  
    def _grayscale(a):
  35.  
    print a.reshape(a.shape[0], 3, 32, 32).mean(1).reshape(a.shape[0], -1)
  36.  
    return a.reshape(a.shape[0], 3, 32, 32).mean(1).reshape(a.shape[0], -1)
  37.  
     
  38.  
     
  39.  
    def cifar10(dtype='float64', grayscale=True):
  40.  
    # train
  41.  
    x_train = []
  42.  
    t_train = []
  43.  
    for k in xrange(5):
  44.  
    x, t = _load_batch_cifar10("data_batch_%d" % (k + 1), dtype=dtype)
  45.  
    x_train.append(x)
  46.  
    t_train.append(t)
  47.  
     
  48.  
    x_train = np.concatenate(x_train, axis=0)
  49.  
    t_train = np.concatenate(t_train, axis=0)
  50.  
     
  51.  
    # test
  52.  
    x_test, t_test = _load_batch_cifar10("test_batch", dtype=dtype)
  53.  
     
  54.  
    if grayscale:
  55.  
    x_train = _grayscale(x_train)
  56.  
    x_test = _grayscale(x_test)
  57.  
     
  58.  
    return x_train, t_train, x_test, t_test
  59.  
     
  60.  
     
  61.  
    def _load_batch_cifar100(filename, dtype='float64'):
  62.  
    """
  63.  
    load a batch in the CIFAR-100 format
  64.  
    """
  65.  
    path = os.path.join(data_dir_cifar100, filename)
  66.  
    batch = np.load(path)
  67.  
    data = batch['data'] / 255.0
  68.  
    labels = one_hot(batch['fine_labels'], n=100)
  69.  
    return data.astype(dtype), labels.astype(dtype)
  70.  
     
  71.  
     
  72.  
    def cifar100(dtype='float64', grayscale=True):
  73.  
    x_train, t_train = _load_batch_cifar100("train", dtype=dtype)
  74.  
    x_test, t_test = _load_batch_cifar100("test", dtype=dtype)
  75.  
     
  76.  
    if grayscale:
  77.  
    x_train = _grayscale(x_train)
  78.  
    x_test = _grayscale(x_test)
  79.  
     
  80.  
    return x_train, t_train, x_test, t_test
  81.  
     
  82.  
    Xtrain, Ytrain, Xtest, Ytest = cifar10()
  83.  
    ################################################
  84.  
     
  85.  
    # 图像样本显示
  86.  
     
  87.  
    image = Xtrain[0].reshape(32, 32)
  88.  
    image1 = Xtrain[255].reshape(32, 32)
  89.  
     
  90.  
    fig = plt.figure()
  91.  
    ax = fig.add_subplot(121)
  92.  
    plt.axis('off')
  93.  
    plt.title(class_names_cifar10['label_names'][list(Ytrain[0]).index(1)])
  94.  
    plt.imshow(image, cmap='gray')
  95.  
     
  96.  
    ax = fig.add_subplot(122)
  97.  
    plt.title(class_names_cifar10['label_names'][list(Ytrain[255]).index(1)])
  98.  
    plt.imshow(image1, cmap='gray')
  99.  
    plt.axis('off')
  100.  
    plt.show()

6.实验结果

7.参考链接

(1)CIFAR数据集:http://www.cs.toronto.edu/~kriz/cifar.html
(2)数据集加载:https://github.com/benanne/theano-tutorial/blob/master/load.py

Theano入门——CIFAR-10和CIFAR-100数据集的更多相关文章

  1. 【翻译】TensorFlow卷积神经网络识别CIFAR 10Convolutional Neural Network (CNN)| CIFAR 10 TensorFlow

    原网址:https://data-flair.training/blogs/cnn-tensorflow-cifar-10/ by DataFlair Team · Published May 21, ...

  2. Spring入门(10)-Spring JDBC

    Spring入门(10)-Spring JDBC 0. 目录 JdbcTemplate介绍 JdbcTemplate常见方法 代码示例 参考资料 1. JdbcTemplate介绍 JdbcTempl ...

  3. 【SSRS】入门篇(三) -- 为报表定义数据集

    原文:[SSRS]入门篇(三) -- 为报表定义数据集 通过前两篇文件 [SSRS]入门篇(一) -- 创建SSRS项目 和 [SSRS]入门篇(二) -- 建立数据源 后, 我们建立了一个SSRS项 ...

  4. python入门(10)使用List和tuple

    python入门(10)使用List和tuple list Python内置的一种数据类型是列表:list.list是一种有序的集合,可以随时添加和删除其中的元素. 比如,列出班里所有同学的名字,就可 ...

  5. Linux pwn入门教程(10)——针对函数重定位流程的几种攻击

    作者:Tangerine@SAINTSEC 本系列的最后一篇 感谢各位看客的支持 感谢原作者的付出一直以来都有读者向笔者咨询教程系列问题,奈何该系列并非笔者所写[笔者仅为代发]且笔者功底薄弱,故无法解 ...

  6. 【JavaScript】随机生成10个0~100的数字

    随机生成10个0~100不重复的数字(包含0和100): 需要用到的知识点:随机数 去重 下面放代码 <!DOCTYPE html> <html> <head> & ...

  7. CSS,让100%的宽度,自动减10,让100%的高度,自动减10,可以加减乘除

    CSS,让100%的宽度,自动减10,让100%的高度,自动减10,可以加减乘除 实例: .add{ width: calc(100% - 10px); height: calc(100% - 10p ...

  8. [易学易懂系列|rustlang语言|零基础|快速入门|(10)|Vectors容器]

    [易学易懂系列|rustlang语言|零基础|快速入门|(10)] 有意思的基础知识 Vectors 我们之前知道array数组是定长,只可我保存相同类型的数据的数据类型. 如果,我们想用不定长的数组 ...

  9. MySql中varchar(10)和varchar(100)的区别

    背景 许多使用MySQL的同学都会使用到varchar这个数据类型.初学者刚开始学习varchar时,一定记得varchar是个变长的类型这个知识点,所以很多初学者在设计表时,就会把varchar(X ...

随机推荐

  1. 第十篇:javaScript中的JSON总结

    参考网站:json中国,MDN json 一.必知基础    JSON 是JavaScript对象文字符号的一个子集,它可以自如的在JavaScript中使用.看下这个对象: var myJSONOb ...

  2. vue项目的实用配置

    文件压缩如何去掉console 在使用vue开发项目的过程中,免不了在调试的时候会写许多console,在控制台进行调试:在开发的时候这种输出是必须的,但是build后线上运行时这个东西是不能出现的: ...

  3. python collections模块 之 defaultdict

    defaultdict 是 dict 的子类,因此 defaultdict 也可被当成 dict 来使用,dict 支持的功能,defaultdict 基本都支持.但它与 dict 最大的区别在于,如 ...

  4. Mybatis-generator/通用Mapper/Mybatis-Plus对比

    mybatis-plus-boot-starter和mybatis-spring-boot-starter冲突导致MapperScan失效问题还没有解决,只能不用mybatis-plus-boot-s ...

  5. pom.xml文件配置maven仓库地址

    中央仓库就是Maven的一个默认的远程仓库,Maven的安装文件中自带了中央仓库的配置($M2_HOME/lib/maven-model-builder.jar) 在很多情况下,默认的中央仓库无法满足 ...

  6. 论一个X倒下了千千万万个X站起来了

        一个人倒下了,千千万万个人站起来了.起源自革命时期的标语.后来又被应用于各种激励的场景.     这句话在当时环境是好的,但是在无数人应用在不同场景,并且这些场景都不怎么好的时候.人的普遍思维 ...

  7. [转]Entity Framework 的实体关系

    通过 Entiy Framework实践系列 文章,理了理 Entity Framework 的实体关系. 为什么要写文章来理清这些关系?“血”的教训啊,刚开始使用 Entity Framework  ...

  8. Java事件监听机制与观察者设计模式

    一. Java事件监听机制 1. 事件监听三要素: 事件源,事件对象,事件监听器 2. 三要素之间的关系:事件源注册事件监听器后,当事件源上发生某个动作时,事件源就会调用事件监听的一个方法,并将事件对 ...

  9. Python3基础笔记_迭代器

    # Python3 迭代器与生成器 import sys ''' 迭代是Python最强大的功能之一,是访问集合元素的一种方式. 迭代器是一个可以记住遍历的位置的对象. 迭代器对象从集合的第一个元素开 ...

  10. c语言学习笔记 - 二进制文件

    在进行文件操作的时候,有时候是用文本的形式存在文件里面,例如用 fprintf(fp,"%d",123) 存一个数据123,实际的存储是已1,2,3这3个ASCII码存入,打开文件 ...