零基础学习人工智能—Python—Pytorch学习(九)
前言
本文主要介绍卷积神经网络的使用的下半部分。
另外,上篇文章增加了一点代码注释,主要是解释(w-f+2p)/s+1这个公式的使用。
所以,要是这篇文章的代码看不太懂,可以翻一下上篇文章。
代码实现
之前,我们已经学习了概念,在结合我们以前学习的知识,我们可以直接阅读下面代码了。
代码里使用了,dataset.CIFAR10数据集。
CIFAR-10 数据集由 60000 张 32x32 彩色图像组成,共分为 10 个不同的类别,分别是飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。
每个类别包含 6000 张图像,其中 50000 张用于训练,10000 张用于测试。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F #nn不好使时,在这里找激活函数
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyper parameters
input_size = 784 # 28x28
hidden_size = 100
num_classes = 10
batch_size = 100
learning_rate = 0.001
num_epochs = 2
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils. data.DataLoader(
dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=batch_size, shuffle=False)
print('每份100个,被分成多少份:', len(test_loader))
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120) #这个在forward里解释
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x))) #这里x已经变成 torch.Size([4, 16, 5, 5])
# print("两次卷积两次池化后的x.shape:",x.shape)
x = x.view(-1,16*5*5)#这里的16*5*5就是x的后面3个维度相乘
x = F.relu(self.fc1(x)) #fc1定义时,inputx已经是6*5*5了
x = F.relu(self.fc2(x))
x= self.fc3(x)
return x
model = ConvNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# origin shape:[4,3,32,32]=4,3,1024
# input layer: 3 input channels, 6 output channels, 5 kernel size
images = images.to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 2000 == 0:
print(
f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
print('Finished Training')
# test
with torch.no_grad():
n_correct = 0
n_samples = 0
n_class_correct = [0 for i in range(10)] #生成 10 个 0 的列表
n_class_samples = [0 for i in range(10)]
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
print('test-images.shape:', images.shape)
outputs = model(images)
# max returns(value ,index)
_, predicted = torch.max(outputs, 1)
n_samples += labels.size(0)
n_correct += (predicted == labels).sum().item()
for i in range(batch_size):
label = labels[i]
# print("label:",label) #这里存的是 0~9的数字 输出就是这样的 label: tensor(2) predicted[i]也是这样的数
pred = predicted[i]
if (label == pred):
n_class_correct[label] += 1
n_class_samples[label] += 1
acc = 100.0*n_correct/n_samples # 计算正确率
print(f'accuracy ={acc}')
for i in range(10):
acc = 100.0*n_class_correct[i]/n_class_samples[i]
print(f'Accuracy of {classes[i]}: {acc} %')
运行结果如下:
accuracy =10.26
Accuracy of plane: 0.0 %
Accuracy of car: 0.0 %
Accuracy of bird: 0.0 %
Accuracy of cat: 0.0 %
Accuracy of deer: 0.0 %
Accuracy of dog: 0.0 %
Accuracy of frog: 0.0 %
Accuracy of horse: 0.0 %
Accuracy of ship: 89.6 %
Accuracy of truck: 13.0 %
这是因为我设置的num_epochs=2,也就是循环的次数太低,所以结果的精确度就很低。
我们只要增加epochs的值,就能提高精确度了。
传送门:
零基础学习人工智能—Python—Pytorch学习—全集
这样我们卷积神经网络就学完了。
注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!

若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!
https://www.cnblogs.com/kiba/p/18381036
零基础学习人工智能—Python—Pytorch学习(九)的更多相关文章
- 如何零基础开始自学Python编程
转载——原作者:赛门喵 链接:https://www.zhihu.com/question/29138020/answer/141170242 0. 明确目标 我是真正零基础开始学Python的,从一 ...
- 零基础快速掌握Python系统管理视频课程【猎豹网校】
点击了解更多Python课程>>> 零基础快速掌握Python系统管理视频课程[猎豹网校] 课程目录 01.第01章 Python简介.mp4 02.第02章 IPython基础.m ...
- 零基础的人该怎么学习JAVA
对于JAVA有所兴趣但又是零基础的人,该如何学习JAVA呢?对于想要学习开发技术的学子来说找到一个合适自己的培训机构是非常难的事情,在选择的过程中总是 因为这样或那样的问题让你犹豫不决,阻碍你前进的 ...
- 零基础学完Python的7大就业方向,哪个赚钱多?
“ 我想学 Python,但是学完 Python 后都能干啥 ?” “ 现在学 Python,哪个方向最简单?哪个方向最吃香 ?” “ …… ” 相信不少 Python 的初学者,都会遇到上面的这些问 ...
- 零基础怎么学Python编程,新手常犯哪些错误?
Python是人工智能时代最佳的编程语言,入门简单.功能强大,深获初学者的喜爱. 很多零基础学习Python开发的人都会忽视一些小细节,进而导致整个程序出现错误.下面就给大家介绍一下Python开发者 ...
- 零基础如何入门Python
编程零基础如何学习Python 如果你是零基础,注意是零基础,想入门编程的话,我推荐你学Python.虽然国内基本上是以C语言作为入门教学,但在麻省理工等国外大学都是以Python作为编程入门教学的. ...
- 零基础如何学Python爬虫技术?
在作者学习的众多编程技能中,爬虫技能无疑是最让作者着迷的.与自己闭关造轮子不同,爬虫的感觉是与别人博弈,一个在不停的构建 反爬虫 规则,一个在不停的破译规则. 如何入门爬虫?零基础如何学爬虫技术?那前 ...
- 零基础自学人工智能,看这些资料就够了(300G资料免费送)
为什么有今天这篇? 首先,标题不要太相信,哈哈哈. 本公众号之前已经就人工智能学习的路径.学习方法.经典学习视频等做过完整说明.但是鉴于每个人的基础不同,可能需要额外的学习资料进行辅助.特此,向大家免 ...
- 零基础自学用Python 3开发网络爬虫
原文出处: Jecvay Notes (@Jecvay) 由于本学期好多神都选了Cisco网络课, 而我这等弱渣没选, 去蹭了一节发现讲的内容虽然我不懂但是还是无爱. 我想既然都本科就出来工作还是按照 ...
- 零基础如何使用python处理字符串?
摘要:Python的普遍使用场景是自动化测试.爬取网页数据.科学分析之类,这其中都涉及到了对数据的处理,而数据的表现形式很多,今天我们来讲讲字符串的操作. 字符串是作为任意一门编程语言的基础,在P ...
随机推荐
- todo高通Android UEFI中的LCD分析(1):启动流程分析
# 高通Android UEFI中的LCD分析(1):启动流程 背景 之前学习的lk阶段点亮LCD的流程算是比较经典,但是高通已经推出了很多种基于UEFI方案的启动架构. 所以需要对这块比较新的技术进 ...
- FFmpeg开发笔记(三十四)Linux环境给FFmpeg集成libsrt和librist
<FFmpeg开发实战:从零基础到短视频上线>一书的"10.2 FFmpeg推流和拉流"提到直播行业存在RTSP和RTMP两种常见的流媒体协议.除此以外,还有比较两 ...
- k8s实战 ---- pod 基础
如果你对k8s还不了解,可以看下前文 k8s 实战 1 ---- 初识 (https://www.cnblogs.com/jilodream/p/18245222) 什么是pod,pod在 ...
- 类、事件与对象---Dad&Mom&Friends(进阶事件)
接上一个笔记:https://www.cnblogs.com/StephenYoung/p/17792668.html 现在增加了一个新的朋友类:Friends 这个类构造如下: 从上到下依次是: 1 ...
- 转载 | win11右键菜单改为win10的bat命令(以及恢复方法bat)
原文来自这里:https://blog.51cto.com/knifeedge/5340751 版权归:IT利刃出鞘 本质上就是写入注册表. 一.右键菜单改回Win10(展开) 1. 新建文件:win ...
- oeasy教您玩转vim - 91 - # vim脚本编程展望
vim脚本编程展望 回忆 上次我们彻底研究了vim高亮的原理 各种语法项syntax item 关键字keyword 匹配模式match 区域region 定义好了之后还可以设置链接成组 hi d ...
- Linux 手工释放Linux Cache Memory
手工释放Linux Cache Memory 为了加速操作和减少磁盘I/O,内核通常会尽可能多地缓存内存,这部分内存就是Cache Memory(缓存内存).根据设计,包含缓存数据的页面可以按需重新用 ...
- pandas无法打开.xlsx文件,xlrd.biffh.XLRDError: Excel xlsx file; not supported
原因是最近xlrd更新到了2.0.1版本,只支持.xls文件.所以pandas.read_excel('xxx.xlsx')会报错. 可以安装旧版xlrd,在cmd中运行: pip uninstall ...
- Cython与C函数的结合
技术背景 在前面一篇博客中,我们介绍了使用Cython加速谐振势计算的方法.有了Cython对于计算过程更加灵活的配置(本质上是时间占用和空间占用的一种均衡),及其接近于C的性能,并且还最大程度上的保 ...
- golang 学习笔记1
1.go的gin框架,没有预设目录,具体目录可以在网上参考.