前言

本文主要介绍卷积神经网络的使用的下半部分。

另外,上篇文章增加了一点代码注释,主要是解释(w-f+2p)/s+1这个公式的使用。

所以,要是这篇文章的代码看不太懂,可以翻一下上篇文章。

代码实现

之前,我们已经学习了概念,在结合我们以前学习的知识,我们可以直接阅读下面代码了。

代码里使用了,dataset.CIFAR10数据集。

CIFAR-10 数据集由 60000 张 32x32 彩色图像组成,共分为 10 个不同的类别,分别是飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。

每个类别包含 6000 张图像,其中 50000 张用于训练,10000 张用于测试。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F #nn不好使时,在这里找激活函数
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyper parameters
input_size = 784 # 28x28
hidden_size = 100
num_classes = 10
batch_size = 100
learning_rate = 0.001
num_epochs = 2 transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform) test_dataset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform) train_loader = torch.utils. data.DataLoader(
dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=batch_size, shuffle=False)
print('每份100个,被分成多少份:', len(test_loader)) classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck') class ConvNet(nn.Module):
def __init__(self):
super(ConvNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120) #这个在forward里解释
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x))) #这里x已经变成 torch.Size([4, 16, 5, 5])
# print("两次卷积两次池化后的x.shape:",x.shape)
x = x.view(-1,16*5*5)#这里的16*5*5就是x的后面3个维度相乘
x = F.relu(self.fc1(x)) #fc1定义时,inputx已经是6*5*5了
x = F.relu(self.fc2(x))
x= self.fc3(x)
return x model = ConvNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) n_total_steps = len(train_loader) for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# origin shape:[4,3,32,32]=4,3,1024
# input layer: 3 input channels, 6 output channels, 5 kernel size
images = images.to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels) # Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 2000 == 0:
print(
f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
print('Finished Training') # test
with torch.no_grad():
n_correct = 0
n_samples = 0
n_class_correct = [0 for i in range(10)] #生成 10 个 0 的列表
n_class_samples = [0 for i in range(10)]
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
print('test-images.shape:', images.shape)
outputs = model(images)
# max returns(value ,index)
_, predicted = torch.max(outputs, 1)
n_samples += labels.size(0)
n_correct += (predicted == labels).sum().item()
for i in range(batch_size):
label = labels[i]
# print("label:",label) #这里存的是 0~9的数字 输出就是这样的 label: tensor(2) predicted[i]也是这样的数
pred = predicted[i]
if (label == pred):
n_class_correct[label] += 1
n_class_samples[label] += 1
acc = 100.0*n_correct/n_samples # 计算正确率
print(f'accuracy ={acc}') for i in range(10):
acc = 100.0*n_class_correct[i]/n_class_samples[i]
print(f'Accuracy of {classes[i]}: {acc} %')

运行结果如下:

accuracy =10.26
Accuracy of plane: 0.0 %
Accuracy of car: 0.0 %
Accuracy of bird: 0.0 %
Accuracy of cat: 0.0 %
Accuracy of deer: 0.0 %
Accuracy of dog: 0.0 %
Accuracy of frog: 0.0 %
Accuracy of horse: 0.0 %
Accuracy of ship: 89.6 %
Accuracy of truck: 13.0 %

这是因为我设置的num_epochs=2,也就是循环的次数太低,所以结果的精确度就很低。

我们只要增加epochs的值,就能提高精确度了。


传送门:

零基础学习人工智能—Python—Pytorch学习—全集

这样我们卷积神经网络就学完了。


注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!



若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!

https://www.cnblogs.com/kiba/p/18381036

零基础学习人工智能—Python—Pytorch学习(九)的更多相关文章

  1. 如何零基础开始自学Python编程

    转载——原作者:赛门喵 链接:https://www.zhihu.com/question/29138020/answer/141170242 0. 明确目标 我是真正零基础开始学Python的,从一 ...

  2. 零基础快速掌握Python系统管理视频课程【猎豹网校】

    点击了解更多Python课程>>> 零基础快速掌握Python系统管理视频课程[猎豹网校] 课程目录 01.第01章 Python简介.mp4 02.第02章 IPython基础.m ...

  3. 零基础的人该怎么学习JAVA

    对于JAVA有所兴趣但又是零基础的人,该如何学习JAVA呢?对于想要学习开发技术的学子来说找到一个合适自己的培训机构是非常难的事情,在选择的过程中总是  因为这样或那样的问题让你犹豫不决,阻碍你前进的 ...

  4. 零基础学完Python的7大就业方向,哪个赚钱多?

    “ 我想学 Python,但是学完 Python 后都能干啥 ?” “ 现在学 Python,哪个方向最简单?哪个方向最吃香 ?” “ …… ” 相信不少 Python 的初学者,都会遇到上面的这些问 ...

  5. 零基础怎么学Python编程,新手常犯哪些错误?

    Python是人工智能时代最佳的编程语言,入门简单.功能强大,深获初学者的喜爱. 很多零基础学习Python开发的人都会忽视一些小细节,进而导致整个程序出现错误.下面就给大家介绍一下Python开发者 ...

  6. 零基础如何入门Python

    编程零基础如何学习Python 如果你是零基础,注意是零基础,想入门编程的话,我推荐你学Python.虽然国内基本上是以C语言作为入门教学,但在麻省理工等国外大学都是以Python作为编程入门教学的. ...

  7. 零基础如何学Python爬虫技术?

    在作者学习的众多编程技能中,爬虫技能无疑是最让作者着迷的.与自己闭关造轮子不同,爬虫的感觉是与别人博弈,一个在不停的构建 反爬虫 规则,一个在不停的破译规则. 如何入门爬虫?零基础如何学爬虫技术?那前 ...

  8. 零基础自学人工智能,看这些资料就够了(300G资料免费送)

    为什么有今天这篇? 首先,标题不要太相信,哈哈哈. 本公众号之前已经就人工智能学习的路径.学习方法.经典学习视频等做过完整说明.但是鉴于每个人的基础不同,可能需要额外的学习资料进行辅助.特此,向大家免 ...

  9. 零基础自学用Python 3开发网络爬虫

    原文出处: Jecvay Notes (@Jecvay) 由于本学期好多神都选了Cisco网络课, 而我这等弱渣没选, 去蹭了一节发现讲的内容虽然我不懂但是还是无爱. 我想既然都本科就出来工作还是按照 ...

  10. 零基础如何使用python处理字符串?

    摘要:Python的普遍使用场景是自动化测试.爬取网页数据.科学分析之类,这其中都涉及到了对数据的处理,而数据的表现形式很多,今天我们来讲讲字符串的操作.   字符串是作为任意一门编程语言的基础,在P ...

随机推荐

  1. uboot load address、entry point、 bootm address以及kernel运行地址的意义及联系

    按各地址起作用的顺序,uboot引导linux内核启动涉及到以下地址: load address: entry point: 这两个地址是mkimage时指定的 bootm address:bootm ...

  2. P9358 题解

    不难发现,最开始有 \(n\) 条链,并且由于每个点最多有一个桥,所以我们的交换操作实际上等价于将相邻的两条链断开,然后将它们后半部分交换.并且每个点在路径中的相对位置不变. 于是考虑维护这些链. 有 ...

  3. Java 自定义注解校验字段唯一性

    业务场景 在项目中,某些情景下我们需要验证编码是否重复,账号是否重复,身份证号是否重复等... 那么有没有办法可以解决这类似的重复代码量呢? 我们可以通过自定义注解校验的方式去实现,在实体类上面加上自 ...

  4. 直接给一个数组项赋值,Vue 能检测到变化吗?

    由于 JavaScript 的限制,Vue 不能检测到以下数组的变动: 当你利用索引直接设置一个数组项时,例如: vm.items[indexOfItem] = newValue 当你修改数组的长度时 ...

  5. 【Azure Developer】一个复制Redis Key到另一个Redis服务的工具(redis_copy_net8)

    介绍一个简单的工具,用于将Redis数据从一个redis端点复制到另一个redis端点,基于原始存储库转换为.NET 8:https://github.com/LuBu0505/redis-copy- ...

  6. Happus:给准备离职成为独立开发者的你 5 点建议

    名字:Happus 开发者 / 团队:Regina Dan 平台:iOS, visionOS 请简要介绍下这款产品 Happus 是你追寻幸福健康关系.甚至提高婚姻生活品质的贴心助手.无论是关系维系. ...

  7. Three光源Target位置改变光照方向不变的问题及解决方法

    0x00 楔子 在 Three.js 中,光源的目标(target)是一种用于指定光源方向的重要元素.在聚光灯中和定向光(DirectionalLight)中都有用到. 有时我们可能会遇到光源目标位置 ...

  8. Day 7 - 哈希与 KMP

    字符串哈希 定义 我们定义一个把字符串映射到整数的函数 \(f\),这个 \(f\) 称为是 \(\text{Hash}\) 函数. 我们希望这个函数 \(f\) 可以方便地帮我们判断两个字符串是否相 ...

  9. CF709B 题解

    洛谷链接&CF 链接 本篇题解为此题较简单做法及较少码量,并且码风优良,请放心阅读. 题目简述 给定 \(N\) 个点,在一条数轴上,位置为 \(x_1,-,x_n\),你的位置为 \(p\) ...

  10. [rCore学习笔记 02]Ubuntu 22虚拟机安装

    写在前面 本随笔是非常菜的菜鸡写的.如有问题请及时提出. 可以联系:1160712160@qq.com GitHhub:https://github.com/WindDevil (目前啥也没有 Ubu ...