AI应用实战课学习总结(10)用CNN做图像分类
大家好,我是Edison。
最近入坑黄佳老师的《AI应用实战课》,记录下我的学习之旅,也算是总结回顾。
今天是我们的第10站,一起了解CNN卷积神经网络 以及 通过CNN做图像分类任务的案例。
CNN卷积神经网络介绍
卷积神经网络(CNN)是一种用于图像识别和处理的人工神经网络,其灵感来自于动物视觉皮层的生物过程。它们由具有可学习权重和偏差的神经元组成。
CNN在至少一个层中使用一种称为卷积的技术,而不是一般的矩阵乘法,卷积是一种特殊的线性运算。
下图展示了一个典型的CNN架构:输入的是图像,输出的是图像的标签。例如下图中输入了一张卡通人物(崔弟鸟)的图片,输出的是几个可能得标签及其概率,其中AI认为Tweety(动画片 崔弟鸟的名字)的概率最高。

那么,从输入到输出之间都经历了什么呢?
输入层一般是图像,这里的图像通常来说图像的张量,它是神经网络能够读取的图片的结构。然后,通过卷积层(Convolution)做图像特征的提取(一般是局部特征),再通过池化(Pooling)降低特征空间的维度,然后继续多次卷积和池化,提取上一层中的特征图的特征,随诊深度网络的加深,特征也就越来越纯,会变得越来越抽象,但神经网络可以理解。最后,经历一个展平层(Flatten Layer)进入全连接层(Fully connected Layer)做一个Softmax激活(激活函数),完成分类输出,上图中输出了3个分类,所有分类的概率值加起来之和为0.7+0.2+0.1=1。
CIFAR-10数据集
接下来,我们要做一个基于CNN的图像分类的案例,那么,就需要一个输入的图片数据集。这里,我们了解一下CIFAR-10数据集,10代表10种常见物体,大概有6万张这10种物体的图片,这个数据集也常用于图像分类问题的教学任务。

这些图片全都是32*32的尺寸,类别包括:飞机、汽车、鸟、猫、鹿、狗、蛙、马、船、卡车,每个类别都有5000张训练图片和1000张测试图片。对于我们用PyTorch来做Demo来说,不需要我们自己将整个数据集手动下载下来并保存到某个目录,使用PyTorch提供的图像库函数会自动帮我们下载和加载到程序中,十分方便。当程序代码完成下载后,CIFAR-10数据集也就会保存到你当前应用程序的目录下:

需要注意的是,下载下来的文件目录中的内容并不是原始的一张张图片,而是已经转化为张量的适合PyTorch读取的格式。
基于CNN做图像分类案例
基线模型:ResNet-18
这里我们使用预训练好的ResNet-18模型作为预训练网络(或者说基线模型),它是一个典型的用于图像识别的CNN神经网络模型。它本身采用了ImageNet的大量图片做了训练,这里我们将其下载下来对我们的CIFAR10数据集做二次训练,也可以称为“迁移学习”。
对于深度学习来说,建议在GPU上进行训练,在CPU上训练会很慢很慢。
Step1. 导入所需要的库
# 1. 导入所需要的库
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
Step2. 下载CIFAR-10数据集
# 2. 下载CIFAR-10数据集
# 设置图像预处理: 图像增强 + 转换为张量 + 标准化
transform = transforms.Compose(
[transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) # 下载训练集和测试集
print("[LOG] Now loading CIFAR-10 dataset for Training...")
trainset = torchvision.datasets.CIFAR10(root='CIFAR10', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64,
shuffle=True, num_workers=2)
print("[LOG] Now loading CIFAR-10 dataset for Testing...")
testset = torchvision.datasets.CIFAR10(root='CIFAR10', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64,
shuffle=False, num_workers=2)
print("[LOG] Loading CIFAR-10 dataset finished.")
Step3. 下载预训练的ResNet-18模型
这里torchvision的models中已经带了resnet-18模型,用起来十分方便。
# 3. 使用ResNet-18作为预训练网络
# 下载预训练的ResNet-18模型
print("[LOG] Now loading model RestNet-18...")
resnet18 = torchvision.models.resnet18(pretrained=True)
print("[LOG] Loading model ResNet-18 finished.")
# 由于CIFAR-10有10个类,我们需要调整ResNet的最后一个全连接层
num_classes = 10
resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)
需要注意的是:由于CIFAR-10有10个类别,因此需要调整一下基线模型的最后一个全连接层,将其num_classes改为10。
Step4. 微调预训练CNN网络
这里开始定义损失函数和优化器,然后一轮又一轮地训练这个网络,并打印出损失。
# 4. 微调预训练的CNN网络
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(resnet18.parameters(), lr=0.001, momentum=0.9)
# 迁移到GPU上(如果有的话)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("[LOG] Current running device: ", device)
resnet18.to(device)
# 训练网络
print("[LOG] Start training...")
for epoch in range(10): # 就演示训练10个epochs
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
# 获取输入数据
inputs, labels = data[0].to(device), data[1].to(device)
# 清零参数梯度
optimizer.zero_grad()
# 前向 + 反向 + 优化
outputs = resnet18(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if i % 200 == 199: # 每200批次打印一次
print('[%d, %5d] loss: %.3f' %
(epoch + 1, i + 1, running_loss / 200))
running_loss = 0.0
print('[LOG] Training finished')
Step5. 测试训练结果(网络性能)
最后,在测试集中进行测试,并打印出该网络的准确度。
# 5. 测试网络性能
correct = 0
total = 0
with torch.no_grad():
for data in testloader:
images, labels = data[0].to(device), data[1].to(device)
outputs = resnet18(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('[LOG] Accuracy of the network on the 10000 test images: %d %%' % (
100 * correct / total))
刚好我这里有一个GPU Runner跑了几分钟就出来了结果:

可以看到,我们借助ResNet-18在几分钟时间内就训练好了一个接近80%准确度的用于CIFAR-10数据集的模型,成本是真的很低,上手也是很快的。
当然,你也可以选择自己写一个CNN网络来做这个处理,但开发成本会高一些,而且效果也不一定有直接基于这些已有CNN网络模型做迁移的效果好。
小结
本文介绍了CNN的基本概念 以及 如何基于预训练的CNN模型对于CIFAR-10数据集做图像分类的案例。基于预训练好的CNN模型作为基线模型,针对你自己的图片数据集做二次训练(迁移学习),通常可以兼顾成本和性能,是值得采用的实践方式。
推荐学习
黄佳,《AI应用实战课》(课程)

黄佳,《图解GPT:大模型是如何构建的》(图书)
黄佳,《动手做AI Agent》(图书)

AI应用实战课学习总结(10)用CNN做图像分类的更多相关文章
- DDD实战课--学习笔记
目录 学好了DDD,你能做什么? 领域驱动设计:微服务设计为什么要选择DDD? 领域.子域.核心域.通用域和支撑域:傻傻分不清? 限界上下文:定义领域边界的利器 实体和值对象:从领域模型的基础单元看系 ...
- 《即时消息技术剖析与实战》学习笔记10——IM系统如何应对高并发
一.IM 系统的高并发场景 IM 系统中,高并发多见于直播互动场景.比如直播间,在直播过程中,观众会给主播打赏.送礼.发送弹幕等,尤其是明星直播间,几十万.上百万人的规模一点也不稀奇.近期随着武汉新型 ...
- 《Angular4从入门到实战》学习笔记
<Angular4从入门到实战>学习笔记 腾讯课堂:米斯特吴 视频讲座 二〇一九年二月十三日星期三14时14分 What Is Angular?(简介) 前端最流行的主流JavaScrip ...
- [AI开发]将深度学习技术应用到实际项目
本文介绍如何将基于深度学习的目标检测算法应用到具体的项目开发中,体现深度学习技术在实际生产中的价值,算是AI算法的一个落地实现.本文算法部分可以参见前面几篇博客: [AI开发]Python+Tenso ...
- Python第十课学习
Python第十课学习 www.cnblogs.com/yuanchenqi/articles/5828233.html 函数: 1 减少代码的重复 2 更易扩展,弹性更强:便于日后文件功能的修改 3 ...
- golang学习笔记10 beego api 用jwt验证auth2 token 获取解码信息
golang学习笔记10 beego api 用jwt验证auth2 token 获取解码信息 Json web token (JWT), 是为了在网络应用环境间传递声明而执行的一种基于JSON的开放 ...
- AI面试必备/深度学习100问1-50题答案解析
AI面试必备/深度学习100问1-50题答案解析 2018年09月04日 15:42:07 刀客123 阅读数 2020更多 分类专栏: 机器学习 转载:https://blog.csdn.net ...
- 实战 迁移学习 VGG19、ResNet50、InceptionV3 实践 猫狗大战 问题
实战 迁移学习 VGG19.ResNet50.InceptionV3 实践 猫狗大战 问题 参考博客:::https://blog.csdn.net/pengdali/article/detail ...
- JavaEE精英进阶课学习笔记《博学谷》
JavaEE精英进阶课学习笔记<博学谷> 第1章 亿可控系统分析与设计 学习目标 了解物联网应用领域及发展现状 能够说出亿可控的核心功能 能够画出亿可控的系统架构图 能够完成亿可控环境的准 ...
- 《机器学习实战》学习笔记——第2章 KNN
一. KNN原理: 1. 有监督的学习 根据已知事例及其类标,对新的实例按照离他最近的K的邻居中出现频率最高的类别进行分类.伪代码如下: 1)计算已知类别数据集中的点与当前点之间的距离 2)按照距离从 ...
随机推荐
- UML中的各种关系
各种关系 UML中的各种关系一览表 名称 英文名称 符号 描述 实现方法 耦合强度 举例 关键词 备注 依赖 dependency 1.当类与类之间有使用关系时就属于依赖关系:2.依赖不具有" ...
- final关键字、Object类--java进阶day01
1.规则 被final修饰的变量,名称都要大写,多单词的名称则需_来分隔 1.修饰方法 method方法已经不能被重写了,因为修饰该方法的是final 2.修饰类 当一个类中所有的成员方法都不想被重写 ...
- 【Linux】Vim 设置
[Linux]Vim 设置 零.起因 刚学Linux,有时候会重装Linux系统,然后默认的vi不太好用,需要进行一些设置,本文简述如何配置一个好用的Vim. 壹.软件安装 sudo apt-get ...
- zookeeper选主机制
Zookeeper选主机制 一.Server工作状态 每个Server在工作过程中有四种状态: LOOKING:竞选状态,当前Server不知道leader是谁,正在搜寻. LEADING:领导者状态 ...
- Python 潮流周刊#98:t-string 语法被正式接纳了(摘要)
本周刊由 Python猫 出品,精心筛选国内外的 250+ 信息源,为你挑选最值得分享的文章.教程.开源项目.软件工具.播客和视频.热门话题等内容.愿景:帮助所有读者精进 Python 技术,并增长职 ...
- 《基于改进Wallace树的Posit乘法单元优化》(一)
原文 文章通过增加特定的计数器.重新设计部分积求和阶段计数器布局 以及改进最终求和阶段使用的加法器,提出一种名为3L-Wallace树的改进Wallace树算法,有效减少了部分积求和的阶段数, 从而降 ...
- Git错误,fatal: refusing to merge unrelated histories
错误:fatal: refusing to merge unrelated histories 中文意思就是拒绝合并不相关的历史, 解决 出现这个问题的最主要原因还是在于本地仓库和远程仓库实际上是独立 ...
- 『Plotly实战指南』--在金融数据可视化中的应用(下)
在金融市场的复杂博弈中,可视化技术如同精密的导航仪. 传统静态图表正在被交互式可视化取代--据Gartner研究,采用动态可视化的投资机构决策效率提升达47%. 本文的目标是探讨如何利用 Plotly ...
- 【HUST】网安|操作系统实验|实验二 进程管理与死锁
目的 1)理解进程/线程的概念和应用编程过程: 2)理解进程/线程的同步机制和应用编程: 任务 1)在Linux下创建一对父子进程. 2)在Linux下创建2个线程A和B,循环输出数据或字符串. 3) ...
- django笔记(3)-数据库操作
一:路由系统 url 1.url(r'^index/', views.index),url(r'^home/',views.Home.as_view()), 一个url对应一个函数或一个类 ...