简介

在上一篇博客:数据挖掘入门系列教程(十一点五)之CNN网络介绍中,介绍了CNN的工作原理和工作流程,在这一篇博客,将具体的使用代码来说明如何使用keras构建一个CNN网络来对CIFAR-10数据集进行训练。

如果对keras不是很熟悉的话,可以去看一看官方文档。或者看一看我前面的博客:数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST,在数据挖掘入门系列教程(十一)这篇博客中使用了keras构建一个DNN网络,并对keras的做了一个入门使用介绍。

CIFAR-10数据集

CIFAR-10数据集是图像的集合,通常用于训练机器学习和计算机视觉算法。它是机器学习研究中使用比较广的数据集之一。CIFAR-10数据集包含10 种不同类别的共6w张32x32彩色图像。10个不同的类别分别代表飞机,汽车,鸟类,猫,鹿,狗,青蛙,马,轮船 和卡车。每个类别有6,000张图像

在keras恰好提供了这些数据集。加载数据集的代码如下所示:

from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

print(x_train.shape, 'x_train samples')
print(x_test.shape, 'x_test samples')
print(y_train.shape, 'y_trian samples')
print(y_test.shape, 'Y_test samples')

输出结果如下:

训练集有5w张图片,测试集有1w张图片。在\(x\)数据集中,图片是\((32,32,3)\),代表图片的大小是\(32 \times 32\),为3通道(R,G,B)的图片。

展示图片内容

我们可以稍微的展示一下图片的内容,python代码如下所示:

import matplotlib.pyplot as plt
%matplotlib inline plt.figure(figsize=(12,10))
x, y = 8, 6 for i in range(x*y):
plt.subplot(y, x, i+1)
plt.imshow(x_train[i],interpolation='nearest')
plt.show()

下面就是数据集中的部分图片:

数据集变换

同样,我们需要将类标签进行one-hot编码:

import keras
# 将类向量转换为二进制类矩阵。
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

实际上这一步还有很多牛逼(骚)操作,比如说对数据集进行增强,变换等等,这样都可以在一定程度上提高模型的鲁棒性,防止过拟合。这里我们就怎么简单怎么来,就只对数据集标签进行one-hot编码就行了。

构建CNN网络

构建的网络模型代码如下所示:

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten,Conv2D, MaxPooling2D # 构建CNN网络
model = Sequential() # 添加卷积层
model.add(Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:]))
# 添加激活层
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu')) # 添加最大池化层
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) # 将上一层输出的数据变成一维
model.add(Flatten())
# 添加全连接层
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax')) # 网络模型的介绍
print(model.summary())

这里解释一下代码:

Conv2D

Conv2D代表2D的卷积层,可能这里会有人问,我的图片不是3通道(RGB)的吗?为什么使用的是Conv2D而不是Conv3D。首先先说明,在Conv2D中的这个“2”代表的是卷积层可以在两个维度(也就是width,length)进行移动。那么同理Conv3D中的“3”代表这个卷积层可以在3个维度进行移动(比如说视频中的width ,length,time)。那么针对RGB这种3通道(channels),卷积过程中输入有多少个通道,则滤波器(卷积核)就有多少个通道。

简单点来说就是:

输入

单色图片的input,是2D, \(w \times h\)

彩色图片的input,是3D,\(w \times h \times channels\)

卷积核filter

单色图片的filter,是2D, \(w \times h\)

彩色图片的filter,是3D, \(w \times h \times channels\)

值得注意的是,卷积之后的结果是二维的。(因为会将3维卷积得到的结果进行相加)

接着继续解释Conv2D的参数:

Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:])

  • 32表示的是输出空间的维度(也就是filter滤波器的输出数量)
  • (3,3)代表的是卷积核的大小
  • strides(这里没有用到):这个代表是滑动的步长。
  • input_shape:输入的维度,这里是(28,28,3)

padding在上一篇博客介绍过,在keras中有两个取值:"valid""same" (大小写敏感)。

  • valid padding:不进行任何处理,只使用原始图像,不允许卷积核超出原始图像边界
  • same padding:进行填充,允许卷积核超出原始图像边界,并使得卷积后结果的大小与原来的一致

Flatten

Flatten这一层就是为了将多维数据变成一维数据:

构建网络

from keras.optimizers import RMSprop
# 利用 RMSprop 来训练模型。
model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy']
)

其他的参数在上两篇博客中已经讲了,就不再赘述。

进行训练评估

这里大家可以根据自己的电脑配置适当调整一下batch_size的大小。

history = model.fit(x_train, y_train,
batch_size=32,
epochs=64,
verbose=1,
validation_data=(x_test, y_test)
)

在i5-10代u,mx250的情况下,训练一轮大概需要27s左右。

训练完成之后,进行评估:

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

结果如下所示:

这个结果可以说的上是一言难尽,

数据挖掘入门系列教程(十二)之使用keras构建CNN网络识别CIFAR10的更多相关文章

  1. 数据挖掘入门系列教程(二)之分类问题OneR算法

    数据挖掘入门系列教程(二)之分类问题OneR算法 数据挖掘入门系列博客:https://www.cnblogs.com/xiaohuiduan/category/1661541.html 项目地址:G ...

  2. 数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST

    简介 在上一篇博客:数据挖掘入门系列教程(十点五)之DNN介绍及公式推导中,详细的介绍了DNN,并对其进行了公式推导.本来这篇博客是准备直接介绍CNN的,但是想了一下,觉得还是使用keras构建一个D ...

  3. 数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例)

    数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例) 简介 scikit-learn 估计器 加载数据集 进行fit训练 设置参数 预处理 流水线 结尾 数据挖掘入门系 ...

  4. 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST

    目录 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST 下载数据集 加载数据集 构建神经网络 反向传播(BP)算法 进行预测 F1验证 总结 参考 数据挖掘入门系 ...

  5. 数据挖掘入门系列教程(九)之基于sklearn的SVM使用

    目录 介绍 基于SVM对MINIST数据集进行分类 使用SVM SVM分析垃圾邮件 加载数据集 分词 构建词云 构建数据集 进行训练 交叉验证 炼丹术 总结 参考 介绍 在上一篇博客:数据挖掘入门系列 ...

  6. CRL快速开发框架系列教程十二(MongoDB支持)

    本系列目录 CRL快速开发框架系列教程一(Code First数据表不需再关心) CRL快速开发框架系列教程二(基于Lambda表达式查询) CRL快速开发框架系列教程三(更新数据) CRL快速开发框 ...

  7. webpack4 系列教程(十二):处理第三方JavaScript库

    教程所示图片使用的是 github 仓库图片,网速过慢的朋友请移步<webpack4 系列教程(十二):处理第三方 JavaScript 库>原文地址.或者来我的小站看更多内容:godbm ...

  8. 数据挖掘入门系列教程(四)之基于scikit-lean实现决策树

    目录 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理Iris 加载数据集 数据特征 训练 随机森林 调参工程师 结尾 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理 ...

  9. 数据挖掘入门系列教程(四点五)之Apriori算法

    目录 数据挖掘入门系列教程(四点五)之Apriori算法 频繁(项集)数据的评判标准 Apriori 算法流程 结尾 数据挖掘入门系列教程(四点五)之Apriori算法 Apriori(先验)算法关联 ...

随机推荐

  1. 2020年Java多线程与并发系列22道高频面试题(附思维导图和答案解析)

    前言 现在不管是大公司还是小公司,去面试都会问到多线程与并发编程的知识,大家面试的时候这方面的知识一定要提前做好储备. 关于多线程与并发的知识总结了一个思维导图,分享给大家 1.Java中实现多线程有 ...

  2. css中(单冒号):after和(双冒号)::after的区别

    :after是伪元素 .::after是伪类 什么是伪类? CSS伪类是添加到选择器的关键字,指定要选择的元素的特殊状态.例如,:hover可被用于在用户将鼠标悬停在按钮上时改变按钮的颜色. ::af ...

  3. python--django中一些有用的工具引入路径

    django.shortcuts render:渲染前端页面 redirect:跳转到其他页面 django forms:表单验证 # 定义 class Form(forms.Form): name ...

  4. 微信小程序placeholder设置自定义颜色

    原地址链接:https://blog.csdn.net/august_leo/article/details/80877382 这是微信小程序input组件的官方文档描述,下图红框里的placehol ...

  5. python 函数--递归函数

    一.递归函数的定义:在一个函数里面调用函数本身 python限制最大层数:998层 def foo(n): print(n) n+=1 foo(n) foo(1)

  6. python3(二十四) subClas

    """ 继承的多态 """ __author__ = 'shaozhiqi' # -----------------父类---------- ...

  7. Java相同id的数据集合,合并数据为一条,并将几个字段内容合并为一个

    Java实现,当然也可以数据库实现; /** * Created by shaozhiqi on 2019/7/31. */ public class TestUnion { @Test public ...

  8. Linux c++ vim环境搭建系列(4)——vim插件安装配置使用

    4. 插件 主要是c++相关的. ~/.vimrc文件在GitHub上有:https://github.com/whuwzp/vim_config 以下内容参考: https://github.com ...

  9. Python分析数据难吗?某科技大学教授说,很难但有方法就简单

    用python分析数据难吗?某科技大学的教授这样说,很难,但要讲方法,主要是因为并不是掌握了基础,就能用python来做数据分析的. 所谓python的基础,也就是刚入门的python学习者,学习的基 ...

  10. 学习Salesforce | Einstein业务机会评分怎么玩

    Einstein 业务机会评分(Opportunity Scoring)是销售团队的得力助手,通过分数以及研究影响分数的因素,确定业务机会的优先级,赢得更多交易. Einstein 业务机会评分可以给 ...