前言

本文主要介绍卷积神经网络的使用的下半部分。

另外,上篇文章增加了一点代码注释,主要是解释(w-f+2p)/s+1这个公式的使用。

所以,要是这篇文章的代码看不太懂,可以翻一下上篇文章。

代码实现

之前,我们已经学习了概念,在结合我们以前学习的知识,我们可以直接阅读下面代码了。

代码里使用了,dataset.CIFAR10数据集。

CIFAR-10 数据集由 60000 张 32x32 彩色图像组成,共分为 10 个不同的类别,分别是飞机、汽车、鸟、猫、鹿、狗、青蛙、马、船和卡车。

每个类别包含 6000 张图像,其中 50000 张用于训练,10000 张用于测试。

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F #nn不好使时,在这里找激活函数
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyper parameters
input_size = 784 # 28x28
hidden_size = 100
num_classes = 10
batch_size = 100
learning_rate = 0.001
num_epochs = 2 transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform) test_dataset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform) train_loader = torch.utils. data.DataLoader(
dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=batch_size, shuffle=False)
print('每份100个,被分成多少份:', len(test_loader)) classes = ('plane', 'car', 'bird', 'cat', 'deer',
'dog', 'frog', 'horse', 'ship', 'truck') class ConvNet(nn.Module):
def __init__(self):
super(ConvNet,self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120) #这个在forward里解释
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x))) #这里x已经变成 torch.Size([4, 16, 5, 5])
# print("两次卷积两次池化后的x.shape:",x.shape)
x = x.view(-1,16*5*5)#这里的16*5*5就是x的后面3个维度相乘
x = F.relu(self.fc1(x)) #fc1定义时,inputx已经是6*5*5了
x = F.relu(self.fc2(x))
x= self.fc3(x)
return x model = ConvNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) n_total_steps = len(train_loader) for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# origin shape:[4,3,32,32]=4,3,1024
# input layer: 3 input channels, 6 output channels, 5 kernel size
images = images.to(device)
labels = labels.to(device)
# Forward pass
outputs = model(images)
loss = criterion(outputs, labels) # Backward and optimize
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 2000 == 0:
print(
f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
print('Finished Training') # test
with torch.no_grad():
n_correct = 0
n_samples = 0
n_class_correct = [0 for i in range(10)] #生成 10 个 0 的列表
n_class_samples = [0 for i in range(10)]
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
print('test-images.shape:', images.shape)
outputs = model(images)
# max returns(value ,index)
_, predicted = torch.max(outputs, 1)
n_samples += labels.size(0)
n_correct += (predicted == labels).sum().item()
for i in range(batch_size):
label = labels[i]
# print("label:",label) #这里存的是 0~9的数字 输出就是这样的 label: tensor(2) predicted[i]也是这样的数
pred = predicted[i]
if (label == pred):
n_class_correct[label] += 1
n_class_samples[label] += 1
acc = 100.0*n_correct/n_samples # 计算正确率
print(f'accuracy ={acc}') for i in range(10):
acc = 100.0*n_class_correct[i]/n_class_samples[i]
print(f'Accuracy of {classes[i]}: {acc} %')

运行结果如下:

accuracy =10.26
Accuracy of plane: 0.0 %
Accuracy of car: 0.0 %
Accuracy of bird: 0.0 %
Accuracy of cat: 0.0 %
Accuracy of deer: 0.0 %
Accuracy of dog: 0.0 %
Accuracy of frog: 0.0 %
Accuracy of horse: 0.0 %
Accuracy of ship: 89.6 %
Accuracy of truck: 13.0 %

这是因为我设置的num_epochs=2,也就是循环的次数太低,所以结果的精确度就很低。

我们只要增加epochs的值,就能提高精确度了。


传送门:

零基础学习人工智能—Python—Pytorch学习—全集

这样我们卷积神经网络就学完了。


注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!



若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!

https://www.cnblogs.com/kiba/p/18381036

零基础学习人工智能—Python—Pytorch学习(九)的更多相关文章

  1. 如何零基础开始自学Python编程

    转载——原作者:赛门喵 链接:https://www.zhihu.com/question/29138020/answer/141170242 0. 明确目标 我是真正零基础开始学Python的,从一 ...

  2. 零基础快速掌握Python系统管理视频课程【猎豹网校】

    点击了解更多Python课程>>> 零基础快速掌握Python系统管理视频课程[猎豹网校] 课程目录 01.第01章 Python简介.mp4 02.第02章 IPython基础.m ...

  3. 零基础的人该怎么学习JAVA

    对于JAVA有所兴趣但又是零基础的人,该如何学习JAVA呢?对于想要学习开发技术的学子来说找到一个合适自己的培训机构是非常难的事情,在选择的过程中总是  因为这样或那样的问题让你犹豫不决,阻碍你前进的 ...

  4. 零基础学完Python的7大就业方向,哪个赚钱多?

    “ 我想学 Python,但是学完 Python 后都能干啥 ?” “ 现在学 Python,哪个方向最简单?哪个方向最吃香 ?” “ …… ” 相信不少 Python 的初学者,都会遇到上面的这些问 ...

  5. 零基础怎么学Python编程,新手常犯哪些错误?

    Python是人工智能时代最佳的编程语言,入门简单.功能强大,深获初学者的喜爱. 很多零基础学习Python开发的人都会忽视一些小细节,进而导致整个程序出现错误.下面就给大家介绍一下Python开发者 ...

  6. 零基础如何入门Python

    编程零基础如何学习Python 如果你是零基础,注意是零基础,想入门编程的话,我推荐你学Python.虽然国内基本上是以C语言作为入门教学,但在麻省理工等国外大学都是以Python作为编程入门教学的. ...

  7. 零基础如何学Python爬虫技术?

    在作者学习的众多编程技能中,爬虫技能无疑是最让作者着迷的.与自己闭关造轮子不同,爬虫的感觉是与别人博弈,一个在不停的构建 反爬虫 规则,一个在不停的破译规则. 如何入门爬虫?零基础如何学爬虫技术?那前 ...

  8. 零基础自学人工智能,看这些资料就够了(300G资料免费送)

    为什么有今天这篇? 首先,标题不要太相信,哈哈哈. 本公众号之前已经就人工智能学习的路径.学习方法.经典学习视频等做过完整说明.但是鉴于每个人的基础不同,可能需要额外的学习资料进行辅助.特此,向大家免 ...

  9. 零基础自学用Python 3开发网络爬虫

    原文出处: Jecvay Notes (@Jecvay) 由于本学期好多神都选了Cisco网络课, 而我这等弱渣没选, 去蹭了一节发现讲的内容虽然我不懂但是还是无爱. 我想既然都本科就出来工作还是按照 ...

  10. 零基础如何使用python处理字符串?

    摘要:Python的普遍使用场景是自动化测试.爬取网页数据.科学分析之类,这其中都涉及到了对数据的处理,而数据的表现形式很多,今天我们来讲讲字符串的操作.   字符串是作为任意一门编程语言的基础,在P ...

随机推荐

  1. P3806 题解

    看到现有的一篇 DSU on tree 的题解复杂度假了,于是我来再写一篇. 首先重新梳理思路,维护每棵子树内深度为某个值的节点是否存在. 维护这个东西可以直接 DSU on tree 也就是把小的子 ...

  2. spring cloud 上云的情况下,Ribbon 客户端负载均衡 与 ALB 服务端负载均衡的选择

    在云环境(例如AWS)中,由于云提供商通常提供强大的负载均衡服务(如AWS的ALB),一般不再需要使用Ribbon这种客户端负载均衡方案.云环境中的负载均衡器通常能够提供更高的可靠性.可扩展性和简化的 ...

  3. jsbarcode 生成条形码,并将生成的条码保存至本地,附源码

    导读 以前生成条码都是外网网站上生成,因生产环境在内网中,上不了外网,只能在项目中生成相应规则,故将此方法整理下来. html <!DOCTYPE html> <html> & ...

  4. Git 奇幻之旅⌛️

    第一天: 本地仓库 故事的主角是小明,一个刚入门编程的小白.他正在为一个项目写代码,但是他发现每次修改代码都很麻烦,因为他要不断地备份文件,而且很容易弄混版本.有一天,他听说了一个叫 Git 的神奇工 ...

  5. iOS开发基础102-后台保活方案

    iOS系统在后台执行程序时,有严格的限制,为了更好地管理资源和电池寿命,iOS会限制应用程序在后台的运行时间.然而,iOS提供了一些特定的策略和技术,使得应用程序可以在特定场景下保持后台运行(即&qu ...

  6. Linux-makefile命令后面的-j4 -j8是什么意思?

    其实是指在编译指定的文件时用多少个线程进行编程的意思~ 相关命令示例如下: make zImage -j8 make modules -j8 --------------------------- m ...

  7. Django配置为连接到Microsoft SQL Server

    可以将Django配置为连接到Microsoft SQL Server 2019.为此,你需要更改数据库设置中的一些配置选项.首先,确保你已经安装了 django 和 pyodbc 这两个库:   p ...

  8. JAVA私有构造函数---java笔记

    在Java中,构造函数是一种特殊的方法,它用于初始化新创建的对象.当我们创建一个类的实例时,构造函数会自动被调用. 构造函数可以有不同的访问修饰符,如public.protected.default( ...

  9. CF1929B Sasha and the Drawing 题解

    CF1929B 题意 给定一个 \(n\times n\) 的正方形,已知正方形最多有 \(4\times n-2\) 条对角线,要求要有至少 \(k\) 条对角线经过至少一块黑色方格,求至少要将几条 ...

  10. Jmeter参数化1-随机数设置

    背景:当新增接口的某个字段是唯一性,每次调用该新增接口都会需要单独传入这个字段,麻烦且繁琐. 解决:jmeter设置随机数参数,然后接口调用该参数就达到了自动性不再需要人工传入不同的值.方便调用接口, ...