零基础学习人工智能—Python—Pytorch学习(九)
前言
本文主要介绍卷积神经网络的使用的下半部分。
另外,上篇文章增加了一点代码注释,主要是解释(w-f+2p)/s+1这个公式的使用。
所以,要是这篇文章的代码看不太懂,可以翻一下上篇文章。
代码实现
之前,我们已经学习了概念,在结合我们以前学习的知识,我们可以直接阅读下面代码了。
代码里使用了,dataset.CIFAR10数据集。
CIFAR-10 数据集由 60000 张 32x32 彩色图像组成,共分为 10 个不同的类别,分别是飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。
每个类别包含 6000 张图像,其中 50000 张用于训练,10000 张用于测试。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F #nn不好使时,在这里找激活函数
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyper parameters
input_size = 784 # 28x28
hidden_size = 100
num_classes = 10
batch_size = 100
learning_rate = 0.001
num_epochs = 2
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform)
test_dataset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils. data.DataLoader(
dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=batch_size, shuffle=False)
print('每份100个,被分成多少份:', len(test_loader))
classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck')
class ConvNet(nn.Module):
def __init__(self):
super(ConvNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120) #这个在forward里解释
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x))) #这里x已经变成 torch.Size([4, 16, 5, 5])
# print("两次卷积两次池化后的x.shape:",x.shape)
x = x.view(-1,16*5*5)#这里的16*5*5就是x的后面3个维度相乘
x = F.relu(self.fc1(x)) #fc1定义时,inputx已经是6*5*5了
x = F.relu(self.fc2(x))
x= self.fc3(x)
return x
model = ConvNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
n_total_steps = len(train_loader)
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# origin shape:[4,3,32,32]=4,3,1024
# input layer: 3 input channels, 6 output channels, 5 kernel size
images = images.to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels)
# Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 2000 == 0:
print(
f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
print('Finished Training')
# test
with torch.no_grad():
n_correct = 0
n_samples = 0
n_class_correct = [0 for i in range(10)] #生成 10 个 0 的列表
n_class_samples = [0 for i in range(10)]
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
print('test-images.shape:', images.shape)
outputs = model(images)
# max returns(value ,index)
_, predicted = torch.max(outputs, 1)
n_samples += labels.size(0)
n_correct += (predicted == labels).sum().item()
for i in range(batch_size):
label = labels[i]
# print("label:",label) #这里存的是 0~9的数字 输出就是这样的 label: tensor(2) predicted[i]也是这样的数
pred = predicted[i]
if (label == pred):
n_class_correct[label] += 1
n_class_samples[label] += 1
acc = 100.0*n_correct/n_samples # 计算正确率
print(f'accuracy ={acc}')
for i in range(10):
acc = 100.0*n_class_correct[i]/n_class_samples[i]
print(f'Accuracy of {classes[i]}: {acc} %')
运行结果如下:
accuracy =10.26
Accuracy of plane: 0.0 %
Accuracy of car: 0.0 %
Accuracy of bird: 0.0 %
Accuracy of cat: 0.0 %
Accuracy of deer: 0.0 %
Accuracy of dog: 0.0 %
Accuracy of frog: 0.0 %
Accuracy of horse: 0.0 %
Accuracy of ship: 89.6 %
Accuracy of truck: 13.0 %
这是因为我设置的num_epochs=2,也就是循环的次数太低,所以结果的精确度就很低。
我们只要增加epochs的值,就能提高精确度了。
传送门:
零基础学习人工智能—Python—Pytorch学习—全集
这样我们卷积神经网络就学完了。
注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!

若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!
https://www.cnblogs.com/kiba/p/18381036
零基础学习人工智能—Python—Pytorch学习(九)的更多相关文章
- 如何零基础开始自学Python编程
转载——原作者:赛门喵 链接:https://www.zhihu.com/question/29138020/answer/141170242 0. 明确目标 我是真正零基础开始学Python的,从一 ...
- 零基础快速掌握Python系统管理视频课程【猎豹网校】
点击了解更多Python课程>>> 零基础快速掌握Python系统管理视频课程[猎豹网校] 课程目录 01.第01章 Python简介.mp4 02.第02章 IPython基础.m ...
- 零基础的人该怎么学习JAVA
对于JAVA有所兴趣但又是零基础的人,该如何学习JAVA呢?对于想要学习开发技术的学子来说找到一个合适自己的培训机构是非常难的事情,在选择的过程中总是 因为这样或那样的问题让你犹豫不决,阻碍你前进的 ...
- 零基础学完Python的7大就业方向,哪个赚钱多?
“ 我想学 Python,但是学完 Python 后都能干啥 ?” “ 现在学 Python,哪个方向最简单?哪个方向最吃香 ?” “ …… ” 相信不少 Python 的初学者,都会遇到上面的这些问 ...
- 零基础怎么学Python编程,新手常犯哪些错误?
Python是人工智能时代最佳的编程语言,入门简单.功能强大,深获初学者的喜爱. 很多零基础学习Python开发的人都会忽视一些小细节,进而导致整个程序出现错误.下面就给大家介绍一下Python开发者 ...
- 零基础如何入门Python
编程零基础如何学习Python 如果你是零基础,注意是零基础,想入门编程的话,我推荐你学Python.虽然国内基本上是以C语言作为入门教学,但在麻省理工等国外大学都是以Python作为编程入门教学的. ...
- 零基础如何学Python爬虫技术?
在作者学习的众多编程技能中,爬虫技能无疑是最让作者着迷的.与自己闭关造轮子不同,爬虫的感觉是与别人博弈,一个在不停的构建 反爬虫 规则,一个在不停的破译规则. 如何入门爬虫?零基础如何学爬虫技术?那前 ...
- 零基础自学人工智能,看这些资料就够了(300G资料免费送)
为什么有今天这篇? 首先,标题不要太相信,哈哈哈. 本公众号之前已经就人工智能学习的路径.学习方法.经典学习视频等做过完整说明.但是鉴于每个人的基础不同,可能需要额外的学习资料进行辅助.特此,向大家免 ...
- 零基础自学用Python 3开发网络爬虫
原文出处: Jecvay Notes (@Jecvay) 由于本学期好多神都选了Cisco网络课, 而我这等弱渣没选, 去蹭了一节发现讲的内容虽然我不懂但是还是无爱. 我想既然都本科就出来工作还是按照 ...
- 零基础如何使用python处理字符串?
摘要:Python的普遍使用场景是自动化测试.爬取网页数据.科学分析之类,这其中都涉及到了对数据的处理,而数据的表现形式很多,今天我们来讲讲字符串的操作. 字符串是作为任意一门编程语言的基础,在P ...
随机推荐
- VBA-合并多个工作簿
'合并多个工作薄,并以工作薄的名字给sheet表命名(每个工作薄只有一张表) Sub test() Dim str As String Dim wb As Workbook str = Dir(&qu ...
- javaApi,mapreduce,awk,scala四种方式实现词频统计
awk方式实现词频统计: 方式一: vi wordcount.awk { for (i = 1; i <=NF;i++) //NF 表示的是浏览记录的域的个数 freq[$i]++ } END{ ...
- P9576 题解
赛时没仔细想,赛后才发现并不难. 将 \(l,r\) 与 \(l',r'\) 是否相交分开讨论. 假若不相交,那么 \(l',r' < l\) 或者 \(l',r' > r\) 并且 \( ...
- ajax过程?
1. 创建ajax对象var xhr = new XMLHttpRequest(); 2.告诉Ajax对象要向哪发送请求,以什么方式发送 //请求方式 请求地址xhr.open('get' ...
- [oeasy]python0018_ ASCII_字符分布_数字_大小写字母_符号_黑暗森林
打包和解包 回忆上次内容 decode 就是解码 解码和编码可以转化 encode 编码 decode 解码 互为逆过程 大小写字母之间序号全都相差(32)10进制 编辑 这是 ...
- gitbook 入门教程之比较代码块差异 diff 插件
在 markdown 文档中显示代码之间的差异的 Gitbook 插件 English | 中文 主页 Github : https://snowdreams1006.github.io/gitboo ...
- 如何免费提取PDF里的图片-pdfimages使用教程
写在前面 本随笔是非常菜的菜鸡写的.如有问题请及时提出. 可以联系:1160712160@qq.com GitHhub:https://github.com/WindDevil (目前啥也没有 动机 ...
- pandas无法打开.xlsx文件,xlrd.biffh.XLRDError: Excel xlsx file; not supported
原因是最近xlrd更新到了2.0.1版本,只支持.xls文件.所以pandas.read_excel('xxx.xlsx')会报错. 可以安装旧版xlrd,在cmd中运行: pip uninstall ...
- Java编程指南:高级技巧解析 - Excel单元格样式的编程设置
最新技术资源(建议收藏) https://www.grapecity.com.cn/resources/ 前言 在Java开发中,处理Excel文件是一项常见的任务.在处理Excel文件时,经常需要对 ...
- 假期小结7爬虫学习requests
这周我初步学习了py爬虫的相关知识,以下是我的部分总结 URL headers(URL头部)是HTTP请求中包含的一部分信息,用于描述.控制和传递请求的各种元数据.它们是位于HTTP请求消息的起始部分 ...