数据挖掘入门系列教程(十二)之使用keras构建CNN网络识别CIFAR10
简介
在上一篇博客:数据挖掘入门系列教程(十一点五)之CNN网络介绍中,介绍了CNN的工作原理和工作流程,在这一篇博客,将具体的使用代码来说明如何使用keras构建一个CNN网络来对CIFAR-10数据集进行训练。
如果对keras不是很熟悉的话,可以去看一看官方文档。或者看一看我前面的博客:数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST,在数据挖掘入门系列教程(十一)这篇博客中使用了keras构建一个DNN网络,并对keras的做了一个入门使用介绍。
CIFAR-10数据集
CIFAR-10数据集是图像的集合,通常用于训练机器学习和计算机视觉算法。它是机器学习研究中使用比较广的数据集之一。CIFAR-10数据集包含10 种不同类别的共6w张32x32彩色图像。10个不同的类别分别代表飞机,汽车,鸟类,猫,鹿,狗,青蛙,马,轮船 和卡车。每个类别有6,000张图像
在keras恰好提供了这些数据集。加载数据集的代码如下所示:
from keras.datasets import cifar10
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print(x_train.shape, 'x_train samples')
print(x_test.shape, 'x_test samples')
print(y_train.shape, 'y_trian samples')
print(y_test.shape, 'Y_test samples')
输出结果如下:

训练集有5w张图片,测试集有1w张图片。在\(x\)数据集中,图片是\((32,32,3)\),代表图片的大小是\(32 \times 32\),为3通道(R,G,B)的图片。
展示图片内容
我们可以稍微的展示一下图片的内容,python代码如下所示:
import matplotlib.pyplot as plt
%matplotlib inline
plt.figure(figsize=(12,10))
x, y = 8, 6
for i in range(x*y):
plt.subplot(y, x, i+1)
plt.imshow(x_train[i],interpolation='nearest')
plt.show()
下面就是数据集中的部分图片:

数据集变换
同样,我们需要将类标签进行one-hot编码:
import keras
# 将类向量转换为二进制类矩阵。
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)
实际上这一步还有很多牛逼(骚)操作,比如说对数据集进行增强,变换等等,这样都可以在一定程度上提高模型的鲁棒性,防止过拟合。这里我们就怎么简单怎么来,就只对数据集标签进行one-hot编码就行了。
构建CNN网络
构建的网络模型代码如下所示:
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten,Conv2D, MaxPooling2D
# 构建CNN网络
model = Sequential()
# 添加卷积层
model.add(Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:]))
# 添加激活层
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu'))
# 添加最大池化层
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
# 将上一层输出的数据变成一维
model.add(Flatten())
# 添加全连接层
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax'))
# 网络模型的介绍
print(model.summary())
这里解释一下代码:
Conv2D
Conv2D代表2D的卷积层,可能这里会有人问,我的图片不是3通道(RGB)的吗?为什么使用的是Conv2D而不是Conv3D。首先先说明,在Conv2D中的这个“2”代表的是卷积层可以在两个维度(也就是width,length)进行移动。那么同理Conv3D中的“3”代表这个卷积层可以在3个维度进行移动(比如说视频中的width ,length,time)。那么针对RGB这种3通道(channels),卷积过程中输入有多少个通道,则滤波器(卷积核)就有多少个通道。
简单点来说就是:
输入
单色图片的input,是2D, \(w \times h\)
彩色图片的input,是3D,\(w \times h \times channels\)
卷积核filter
单色图片的filter,是2D, \(w \times h\)
彩色图片的filter,是3D, \(w \times h \times channels\)
值得注意的是,卷积之后的结果是二维的。(因为会将3维卷积得到的结果进行相加)
接着继续解释Conv2D的参数:
Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:])
32表示的是输出空间的维度(也就是filter滤波器的输出数量)(3,3)代表的是卷积核的大小strides(这里没有用到):这个代表是滑动的步长。input_shape:输入的维度,这里是(28,28,3)
padding在上一篇博客介绍过,在keras中有两个取值:"valid" 或 "same" (大小写敏感)。
- valid padding:不进行任何处理,只使用原始图像,不允许卷积核超出原始图像边界
- same padding:进行填充,允许卷积核超出原始图像边界,并使得卷积后结果的大小与原来的一致

Flatten
Flatten这一层就是为了将多维数据变成一维数据:

构建网络
from keras.optimizers import RMSprop
# 利用 RMSprop 来训练模型。
model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy']
)
其他的参数在上两篇博客中已经讲了,就不再赘述。
进行训练评估
这里大家可以根据自己的电脑配置适当调整一下batch_size的大小。
history = model.fit(x_train, y_train,
batch_size=32,
epochs=64,
verbose=1,
validation_data=(x_test, y_test)
)
在i5-10代u,mx250的情况下,训练一轮大概需要27s左右。
训练完成之后,进行评估:
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
结果如下所示:

这个结果可以说的上是一言难尽,
数据挖掘入门系列教程(十二)之使用keras构建CNN网络识别CIFAR10的更多相关文章
- 数据挖掘入门系列教程(二)之分类问题OneR算法
数据挖掘入门系列教程(二)之分类问题OneR算法 数据挖掘入门系列博客:https://www.cnblogs.com/xiaohuiduan/category/1661541.html 项目地址:G ...
- 数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST
简介 在上一篇博客:数据挖掘入门系列教程(十点五)之DNN介绍及公式推导中,详细的介绍了DNN,并对其进行了公式推导.本来这篇博客是准备直接介绍CNN的,但是想了一下,觉得还是使用keras构建一个D ...
- 数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例)
数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例) 简介 scikit-learn 估计器 加载数据集 进行fit训练 设置参数 预处理 流水线 结尾 数据挖掘入门系 ...
- 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST
目录 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST 下载数据集 加载数据集 构建神经网络 反向传播(BP)算法 进行预测 F1验证 总结 参考 数据挖掘入门系 ...
- 数据挖掘入门系列教程(九)之基于sklearn的SVM使用
目录 介绍 基于SVM对MINIST数据集进行分类 使用SVM SVM分析垃圾邮件 加载数据集 分词 构建词云 构建数据集 进行训练 交叉验证 炼丹术 总结 参考 介绍 在上一篇博客:数据挖掘入门系列 ...
- CRL快速开发框架系列教程十二(MongoDB支持)
本系列目录 CRL快速开发框架系列教程一(Code First数据表不需再关心) CRL快速开发框架系列教程二(基于Lambda表达式查询) CRL快速开发框架系列教程三(更新数据) CRL快速开发框 ...
- webpack4 系列教程(十二):处理第三方JavaScript库
教程所示图片使用的是 github 仓库图片,网速过慢的朋友请移步<webpack4 系列教程(十二):处理第三方 JavaScript 库>原文地址.或者来我的小站看更多内容:godbm ...
- 数据挖掘入门系列教程(四)之基于scikit-lean实现决策树
目录 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理Iris 加载数据集 数据特征 训练 随机森林 调参工程师 结尾 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理 ...
- 数据挖掘入门系列教程(四点五)之Apriori算法
目录 数据挖掘入门系列教程(四点五)之Apriori算法 频繁(项集)数据的评判标准 Apriori 算法流程 结尾 数据挖掘入门系列教程(四点五)之Apriori算法 Apriori(先验)算法关联 ...
随机推荐
- 通俗易懂.NET GC垃圾回收机制(适用于小白面试,大牛勿喷)
情景:你接到xx公司面试邀请,你怀着激动忐忑的心坐在对方公司会议室,想着等会的技术面试.技术总监此时走来,与你简单交谈后.... 技术:你对GC垃圾回收机制了解的怎么样? 你:还行,有简单了解过. 技 ...
- CCF2018 12 2题,小明终于到家了
最近在愁着备考,拿CCF刷题,就遇到这个难题,最后搜索了一下大佬们的方法,终于解决, 问题描述 一次放学的时候,小明已经规划好了自己回家的路线,并且能够预测经过各个路段的时间.同时,小明通过学校里安装 ...
- Redis 笔记(二)—— STRING 常用命令
字符串中不仅仅可以存储字符串,它可以存储以下 3 中类型的值 : 字符串 整数 浮点数 Redis 可以对字符串进行截取等相关操作,对整数.浮点数进行增减操作. 自增自减命令 命令 用例和描述 INC ...
- spring07
关于spring的泛型依赖注入主要是继承等方面的知识 具体实现的简单的代码如下: package bao1; public class BaseRepository <T>{ } pack ...
- Mysql大数据量问题与解决
今日格言:了解了为什么,问题就解决了一半. Mysql 单表适合的最大数据量是多少? 我们说 Mysql 单表适合存储的最大数据量,自然不是说能够存储的最大数据量,如果是说能够存储的最大量,那么,如果 ...
- 7.1 java 类、(成员)变量、(成员)方法
/* * 面向对象思想: * 面向对象是基于面向过程的编程思想. * * 面向过程:强调的是每一个功能的步骤 * 面向对象:强调的是对象,然后由对象去调用功能 * * 面向对象的思想特点: * A:是 ...
- Tomcat启动过程原理详解 -- 非常的报错:涉及了2个web.xml等文件的加载流程
Tomcat启动过程原理详解 发表于: Tomcat, Web Server, 旧文存档 | 作者: 谋万世全局者 标签: Tomcat,原理,启动过程,详解 基于Java的Web 应用程序是 ser ...
- Mac Jenkins+fastlane 简单几步实现iOS自动化打包发布 + jenkins节点设置
最近在使用jenkins 实现ios自动化打包发布蒲公英过程实践遇到了一些坑,特意记录下来方便有需要的人. 进入正题: 一.安装Jenkins 1.Mac上安装Jenkins 遇到到坑 因为 Jenk ...
- CSS盒模型属性详细介绍
一.概述 CSS盒模型是定义元素周围的间隔.尺寸.外边距.边框以及文本内容和边框之间内边距的一组属性的集合. 示例代码: <!DOCTYPE html> <html lang=&qu ...
- S - Primitive Primes CodeForces - 1316C 数学
数学题 在f(x)和g(x)的系数里找到第一个不是p的倍数的数,然后相加就是答案 为什么? 设x1为f(x)中第一个不是p的倍数的系数,x2为g(x)...... x1+x2前的系数为(a[x1+x2 ...