简介

在上一篇博客:数据挖掘入门系列教程(十一点五)之CNN网络介绍中,介绍了CNN的工作原理和工作流程,在这一篇博客,将具体的使用代码来说明如何使用keras构建一个CNN网络来对CIFAR-10数据集进行训练。

如果对keras不是很熟悉的话,可以去看一看官方文档。或者看一看我前面的博客:数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST,在数据挖掘入门系列教程(十一)这篇博客中使用了keras构建一个DNN网络,并对keras的做了一个入门使用介绍。

CIFAR-10数据集

CIFAR-10数据集是图像的集合,通常用于训练机器学习和计算机视觉算法。它是机器学习研究中使用比较广的数据集之一。CIFAR-10数据集包含10 种不同类别的共6w张32x32彩色图像。10个不同的类别分别代表飞机,汽车,鸟类,猫,鹿,狗,青蛙,马,轮船 和卡车。每个类别有6,000张图像

在keras恰好提供了这些数据集。加载数据集的代码如下所示:

from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

print(x_train.shape, 'x_train samples')
print(x_test.shape, 'x_test samples')
print(y_train.shape, 'y_trian samples')
print(y_test.shape, 'Y_test samples')

输出结果如下:

训练集有5w张图片,测试集有1w张图片。在\(x\)数据集中,图片是\((32,32,3)\),代表图片的大小是\(32 \times 32\),为3通道(R,G,B)的图片。

展示图片内容

我们可以稍微的展示一下图片的内容,python代码如下所示:

import matplotlib.pyplot as plt
%matplotlib inline plt.figure(figsize=(12,10))
x, y = 8, 6 for i in range(x*y):
plt.subplot(y, x, i+1)
plt.imshow(x_train[i],interpolation='nearest')
plt.show()

下面就是数据集中的部分图片:

数据集变换

同样,我们需要将类标签进行one-hot编码:

import keras
# 将类向量转换为二进制类矩阵。
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

实际上这一步还有很多牛逼(骚)操作,比如说对数据集进行增强,变换等等,这样都可以在一定程度上提高模型的鲁棒性,防止过拟合。这里我们就怎么简单怎么来,就只对数据集标签进行one-hot编码就行了。

构建CNN网络

构建的网络模型代码如下所示:

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten,Conv2D, MaxPooling2D # 构建CNN网络
model = Sequential() # 添加卷积层
model.add(Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:]))
# 添加激活层
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu')) # 添加最大池化层
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) # 将上一层输出的数据变成一维
model.add(Flatten())
# 添加全连接层
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax')) # 网络模型的介绍
print(model.summary())

这里解释一下代码:

Conv2D

Conv2D代表2D的卷积层,可能这里会有人问,我的图片不是3通道(RGB)的吗?为什么使用的是Conv2D而不是Conv3D。首先先说明,在Conv2D中的这个“2”代表的是卷积层可以在两个维度(也就是width,length)进行移动。那么同理Conv3D中的“3”代表这个卷积层可以在3个维度进行移动(比如说视频中的width ,length,time)。那么针对RGB这种3通道(channels),卷积过程中输入有多少个通道,则滤波器(卷积核)就有多少个通道。

简单点来说就是:

输入

单色图片的input,是2D, \(w \times h\)

彩色图片的input,是3D,\(w \times h \times channels\)

卷积核filter

单色图片的filter,是2D, \(w \times h\)

彩色图片的filter,是3D, \(w \times h \times channels\)

值得注意的是,卷积之后的结果是二维的。(因为会将3维卷积得到的结果进行相加)

接着继续解释Conv2D的参数:

Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:])

  • 32表示的是输出空间的维度(也就是filter滤波器的输出数量)
  • (3,3)代表的是卷积核的大小
  • strides(这里没有用到):这个代表是滑动的步长。
  • input_shape:输入的维度,这里是(28,28,3)

padding在上一篇博客介绍过,在keras中有两个取值:"valid""same" (大小写敏感)。

  • valid padding:不进行任何处理,只使用原始图像,不允许卷积核超出原始图像边界
  • same padding:进行填充,允许卷积核超出原始图像边界,并使得卷积后结果的大小与原来的一致

Flatten

Flatten这一层就是为了将多维数据变成一维数据:

构建网络

from keras.optimizers import RMSprop
# 利用 RMSprop 来训练模型。
model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy']
)

其他的参数在上两篇博客中已经讲了,就不再赘述。

进行训练评估

这里大家可以根据自己的电脑配置适当调整一下batch_size的大小。

history = model.fit(x_train, y_train,
batch_size=32,
epochs=64,
verbose=1,
validation_data=(x_test, y_test)
)

在i5-10代u,mx250的情况下,训练一轮大概需要27s左右。

训练完成之后,进行评估:

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

结果如下所示:

这个结果可以说的上是一言难尽,

数据挖掘入门系列教程(十二)之使用keras构建CNN网络识别CIFAR10的更多相关文章

  1. 数据挖掘入门系列教程(二)之分类问题OneR算法

    数据挖掘入门系列教程(二)之分类问题OneR算法 数据挖掘入门系列博客:https://www.cnblogs.com/xiaohuiduan/category/1661541.html 项目地址:G ...

  2. 数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST

    简介 在上一篇博客:数据挖掘入门系列教程(十点五)之DNN介绍及公式推导中,详细的介绍了DNN,并对其进行了公式推导.本来这篇博客是准备直接介绍CNN的,但是想了一下,觉得还是使用keras构建一个D ...

  3. 数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例)

    数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例) 简介 scikit-learn 估计器 加载数据集 进行fit训练 设置参数 预处理 流水线 结尾 数据挖掘入门系 ...

  4. 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST

    目录 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST 下载数据集 加载数据集 构建神经网络 反向传播(BP)算法 进行预测 F1验证 总结 参考 数据挖掘入门系 ...

  5. 数据挖掘入门系列教程(九)之基于sklearn的SVM使用

    目录 介绍 基于SVM对MINIST数据集进行分类 使用SVM SVM分析垃圾邮件 加载数据集 分词 构建词云 构建数据集 进行训练 交叉验证 炼丹术 总结 参考 介绍 在上一篇博客:数据挖掘入门系列 ...

  6. CRL快速开发框架系列教程十二(MongoDB支持)

    本系列目录 CRL快速开发框架系列教程一(Code First数据表不需再关心) CRL快速开发框架系列教程二(基于Lambda表达式查询) CRL快速开发框架系列教程三(更新数据) CRL快速开发框 ...

  7. webpack4 系列教程(十二):处理第三方JavaScript库

    教程所示图片使用的是 github 仓库图片,网速过慢的朋友请移步<webpack4 系列教程(十二):处理第三方 JavaScript 库>原文地址.或者来我的小站看更多内容:godbm ...

  8. 数据挖掘入门系列教程(四)之基于scikit-lean实现决策树

    目录 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理Iris 加载数据集 数据特征 训练 随机森林 调参工程师 结尾 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理 ...

  9. 数据挖掘入门系列教程(四点五)之Apriori算法

    目录 数据挖掘入门系列教程(四点五)之Apriori算法 频繁(项集)数据的评判标准 Apriori 算法流程 结尾 数据挖掘入门系列教程(四点五)之Apriori算法 Apriori(先验)算法关联 ...

随机推荐

  1. 通俗易懂.NET GC垃圾回收机制(适用于小白面试,大牛勿喷)

    情景:你接到xx公司面试邀请,你怀着激动忐忑的心坐在对方公司会议室,想着等会的技术面试.技术总监此时走来,与你简单交谈后.... 技术:你对GC垃圾回收机制了解的怎么样? 你:还行,有简单了解过. 技 ...

  2. CCF2018 12 2题,小明终于到家了

    最近在愁着备考,拿CCF刷题,就遇到这个难题,最后搜索了一下大佬们的方法,终于解决, 问题描述 一次放学的时候,小明已经规划好了自己回家的路线,并且能够预测经过各个路段的时间.同时,小明通过学校里安装 ...

  3. Redis 笔记(二)—— STRING 常用命令

    字符串中不仅仅可以存储字符串,它可以存储以下 3 中类型的值 : 字符串 整数 浮点数 Redis 可以对字符串进行截取等相关操作,对整数.浮点数进行增减操作. 自增自减命令 命令 用例和描述 INC ...

  4. spring07

    关于spring的泛型依赖注入主要是继承等方面的知识 具体实现的简单的代码如下: package bao1; public class BaseRepository <T>{ } pack ...

  5. Mysql大数据量问题与解决

    今日格言:了解了为什么,问题就解决了一半. Mysql 单表适合的最大数据量是多少? 我们说 Mysql 单表适合存储的最大数据量,自然不是说能够存储的最大数据量,如果是说能够存储的最大量,那么,如果 ...

  6. 7.1 java 类、(成员)变量、(成员)方法

    /* * 面向对象思想: * 面向对象是基于面向过程的编程思想. * * 面向过程:强调的是每一个功能的步骤 * 面向对象:强调的是对象,然后由对象去调用功能 * * 面向对象的思想特点: * A:是 ...

  7. Tomcat启动过程原理详解 -- 非常的报错:涉及了2个web.xml等文件的加载流程

    Tomcat启动过程原理详解 发表于: Tomcat, Web Server, 旧文存档 | 作者: 谋万世全局者 标签: Tomcat,原理,启动过程,详解 基于Java的Web 应用程序是 ser ...

  8. Mac Jenkins+fastlane 简单几步实现iOS自动化打包发布 + jenkins节点设置

    最近在使用jenkins 实现ios自动化打包发布蒲公英过程实践遇到了一些坑,特意记录下来方便有需要的人. 进入正题: 一.安装Jenkins 1.Mac上安装Jenkins 遇到到坑 因为 Jenk ...

  9. CSS盒模型属性详细介绍

    一.概述 CSS盒模型是定义元素周围的间隔.尺寸.外边距.边框以及文本内容和边框之间内边距的一组属性的集合. 示例代码: <!DOCTYPE html> <html lang=&qu ...

  10. S - Primitive Primes CodeForces - 1316C 数学

    数学题 在f(x)和g(x)的系数里找到第一个不是p的倍数的数,然后相加就是答案 为什么? 设x1为f(x)中第一个不是p的倍数的系数,x2为g(x)...... x1+x2前的系数为(a[x1+x2 ...