简介

在上一篇博客:数据挖掘入门系列教程(十一点五)之CNN网络介绍中,介绍了CNN的工作原理和工作流程,在这一篇博客,将具体的使用代码来说明如何使用keras构建一个CNN网络来对CIFAR-10数据集进行训练。

如果对keras不是很熟悉的话,可以去看一看官方文档。或者看一看我前面的博客:数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST,在数据挖掘入门系列教程(十一)这篇博客中使用了keras构建一个DNN网络,并对keras的做了一个入门使用介绍。

CIFAR-10数据集

CIFAR-10数据集是图像的集合,通常用于训练机器学习和计算机视觉算法。它是机器学习研究中使用比较广的数据集之一。CIFAR-10数据集包含10 种不同类别的共6w张32x32彩色图像。10个不同的类别分别代表飞机,汽车,鸟类,猫,鹿,狗,青蛙,马,轮船 和卡车。每个类别有6,000张图像

在keras恰好提供了这些数据集。加载数据集的代码如下所示:

from keras.datasets import cifar10

(x_train, y_train), (x_test, y_test) = cifar10.load_data()

print(x_train.shape, 'x_train samples')
print(x_test.shape, 'x_test samples')
print(y_train.shape, 'y_trian samples')
print(y_test.shape, 'Y_test samples')

输出结果如下:

训练集有5w张图片,测试集有1w张图片。在\(x\)数据集中,图片是\((32,32,3)\),代表图片的大小是\(32 \times 32\),为3通道(R,G,B)的图片。

展示图片内容

我们可以稍微的展示一下图片的内容,python代码如下所示:

import matplotlib.pyplot as plt
%matplotlib inline plt.figure(figsize=(12,10))
x, y = 8, 6 for i in range(x*y):
plt.subplot(y, x, i+1)
plt.imshow(x_train[i],interpolation='nearest')
plt.show()

下面就是数据集中的部分图片:

数据集变换

同样,我们需要将类标签进行one-hot编码:

import keras
# 将类向量转换为二进制类矩阵。
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

实际上这一步还有很多牛逼(骚)操作,比如说对数据集进行增强,变换等等,这样都可以在一定程度上提高模型的鲁棒性,防止过拟合。这里我们就怎么简单怎么来,就只对数据集标签进行one-hot编码就行了。

构建CNN网络

构建的网络模型代码如下所示:

from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten,Conv2D, MaxPooling2D # 构建CNN网络
model = Sequential() # 添加卷积层
model.add(Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:]))
# 添加激活层
model.add(Activation('relu'))
model.add(Conv2D(32, (3, 3)))
model.add(Activation('relu')) # 添加最大池化层
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) model.add(Conv2D(64, (3, 3), padding='same'))
model.add(Activation('relu'))
model.add(Conv2D(64, (3, 3)))
model.add(Activation('relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25)) # 将上一层输出的数据变成一维
model.add(Flatten())
# 添加全连接层
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10))
model.add(Activation('softmax')) # 网络模型的介绍
print(model.summary())

这里解释一下代码:

Conv2D

Conv2D代表2D的卷积层,可能这里会有人问,我的图片不是3通道(RGB)的吗?为什么使用的是Conv2D而不是Conv3D。首先先说明,在Conv2D中的这个“2”代表的是卷积层可以在两个维度(也就是width,length)进行移动。那么同理Conv3D中的“3”代表这个卷积层可以在3个维度进行移动(比如说视频中的width ,length,time)。那么针对RGB这种3通道(channels),卷积过程中输入有多少个通道,则滤波器(卷积核)就有多少个通道。

简单点来说就是:

输入

单色图片的input,是2D, \(w \times h\)

彩色图片的input,是3D,\(w \times h \times channels\)

卷积核filter

单色图片的filter,是2D, \(w \times h\)

彩色图片的filter,是3D, \(w \times h \times channels\)

值得注意的是,卷积之后的结果是二维的。(因为会将3维卷积得到的结果进行相加)

接着继续解释Conv2D的参数:

Conv2D(32, (3, 3), padding='same',input_shape=x_train.shape[1:])

  • 32表示的是输出空间的维度(也就是filter滤波器的输出数量)
  • (3,3)代表的是卷积核的大小
  • strides(这里没有用到):这个代表是滑动的步长。
  • input_shape:输入的维度,这里是(28,28,3)

padding在上一篇博客介绍过,在keras中有两个取值:"valid""same" (大小写敏感)。

  • valid padding:不进行任何处理,只使用原始图像,不允许卷积核超出原始图像边界
  • same padding:进行填充,允许卷积核超出原始图像边界,并使得卷积后结果的大小与原来的一致

Flatten

Flatten这一层就是为了将多维数据变成一维数据:

构建网络

from keras.optimizers import RMSprop
# 利用 RMSprop 来训练模型。
model.compile(loss='categorical_crossentropy',
optimizer=RMSprop(),
metrics=['accuracy']
)

其他的参数在上两篇博客中已经讲了,就不再赘述。

进行训练评估

这里大家可以根据自己的电脑配置适当调整一下batch_size的大小。

history = model.fit(x_train, y_train,
batch_size=32,
epochs=64,
verbose=1,
validation_data=(x_test, y_test)
)

在i5-10代u,mx250的情况下,训练一轮大概需要27s左右。

训练完成之后,进行评估:

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])

结果如下所示:

这个结果可以说的上是一言难尽,

数据挖掘入门系列教程(十二)之使用keras构建CNN网络识别CIFAR10的更多相关文章

  1. 数据挖掘入门系列教程(二)之分类问题OneR算法

    数据挖掘入门系列教程(二)之分类问题OneR算法 数据挖掘入门系列博客:https://www.cnblogs.com/xiaohuiduan/category/1661541.html 项目地址:G ...

  2. 数据挖掘入门系列教程(十一)之keras入门使用以及构建DNN网络识别MNIST

    简介 在上一篇博客:数据挖掘入门系列教程(十点五)之DNN介绍及公式推导中,详细的介绍了DNN,并对其进行了公式推导.本来这篇博客是准备直接介绍CNN的,但是想了一下,觉得还是使用keras构建一个D ...

  3. 数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例)

    数据挖掘入门系列教程(三)之scikit-learn框架基本使用(以K近邻算法为例) 简介 scikit-learn 估计器 加载数据集 进行fit训练 设置参数 预处理 流水线 结尾 数据挖掘入门系 ...

  4. 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST

    目录 数据挖掘入门系列教程(八)之使用神经网络(基于pybrain)识别数字手写集MNIST 下载数据集 加载数据集 构建神经网络 反向传播(BP)算法 进行预测 F1验证 总结 参考 数据挖掘入门系 ...

  5. 数据挖掘入门系列教程(九)之基于sklearn的SVM使用

    目录 介绍 基于SVM对MINIST数据集进行分类 使用SVM SVM分析垃圾邮件 加载数据集 分词 构建词云 构建数据集 进行训练 交叉验证 炼丹术 总结 参考 介绍 在上一篇博客:数据挖掘入门系列 ...

  6. CRL快速开发框架系列教程十二(MongoDB支持)

    本系列目录 CRL快速开发框架系列教程一(Code First数据表不需再关心) CRL快速开发框架系列教程二(基于Lambda表达式查询) CRL快速开发框架系列教程三(更新数据) CRL快速开发框 ...

  7. webpack4 系列教程(十二):处理第三方JavaScript库

    教程所示图片使用的是 github 仓库图片,网速过慢的朋友请移步<webpack4 系列教程(十二):处理第三方 JavaScript 库>原文地址.或者来我的小站看更多内容:godbm ...

  8. 数据挖掘入门系列教程(四)之基于scikit-lean实现决策树

    目录 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理Iris 加载数据集 数据特征 训练 随机森林 调参工程师 结尾 数据挖掘入门系列教程(四)之基于scikit-lean决策树处理 ...

  9. 数据挖掘入门系列教程(四点五)之Apriori算法

    目录 数据挖掘入门系列教程(四点五)之Apriori算法 频繁(项集)数据的评判标准 Apriori 算法流程 结尾 数据挖掘入门系列教程(四点五)之Apriori算法 Apriori(先验)算法关联 ...

随机推荐

  1. Docker基础修炼1--Docker简介及快速入门体验

    本文作为Docker基础系列第一篇文章,将详细阐述和分析三个问题:Docker是什么?为什么要用Docker?如何快速掌握Docker技术? 本系列文章中Docker的用法演示是基于CentOS7进行 ...

  2. uni-app在线引入阿里字体图标库

    第一步 在app.vue中引入阿里字体图标库 第二步 在任意页面使用就可以了 <view class="item" v-for="(value,index) in ...

  3. Vulnhub homeless靶机渗透

    信息搜集 nmap -sP 192.168.146.6 nmap -A -Pn 192.168.146.151 直接访问web服务. 大概浏览一下没发现什么,直接扫描下目录把dirb+bp. BP具体 ...

  4. flask-url_for

    flask-url_for flask的url_for函数和django的reverse函数类似,都是提供视图反转url的方法 from flask import Flask, url_for app ...

  5. CVPR 2020论文收藏(转知乎:https://zhuanlan.zhihu.com/p/112337176)

    CVPR 2020 共收录 1470篇文章,根据当前的公布情况,人工智能学社整理了以下约100篇,分享给读者. 代码开源情况:详见每篇注释,当前共15篇开源.(持续更新中,可关注了解). 算法主要领域 ...

  6. Linux 磁盘管理篇,开机挂载

    设置开机挂载需要到 /etc/fstab 里设置 第一列:磁盘设备文件名或该设备的label 第二列:挂载点 第三列:磁盘分区文件系统 第四列:文件系统参数 第五列:能否被dump备份命令作用 第六列 ...

  7. 基础类封装-pymysql库操作mysql封装

    import pymysql from lib.logger import logger from warnings import filterwarnings filterwarnings(&quo ...

  8. vue(element)中使用codemirror实现代码高亮,代码补全,版本差异对比

    vue(element)中使用codemirror实现代码高亮,代码补全,版本差异对比 使用的是vue语言,用element的组件,要做一个在线编辑代码,要求输入代码内容,可以进行高亮展示,可以切换各 ...

  9. 文档根元素 "beans" 必须匹配 DOCTYPE 根 "null"

    文档根元素 "beans" 必须匹配 DOCTYPE 根 "null" (2011-11-20 21:26:41) 转载▼ 标签: 杂谈 分类: spring- ...

  10. 曹工说Redis源码(6)-- redis server 主循环大体流程解析

    文章导航 Redis源码系列的初衷,是帮助我们更好地理解Redis,更懂Redis,而怎么才能懂,光看是不够的,建议跟着下面的这一篇,把环境搭建起来,后续可以自己阅读源码,或者跟着我这边一起阅读.由于 ...