零基础学习人工智能—Python—Pytorch学习(七)
前言
本文主要讲神经网络的下半部分。
其实就是结合之前学习的全部内容,进行一次神经网络的训练。
神经网络
下面是使用MNIST数据集进行的手写数字识别的神经网络训练和使用。
MNIST 数据集,是一个常用的手写数字识别数据集。MNIST 数据集包含 60,000 张 28x28 像素的灰度训练图像和 10,000 张测试图像,每张图像都表示一个手写的数字(0-9)。
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
# device config
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# hyper parameters
input_size = 784 # 28x28
hidden_size = 100
num_classes = 10
batch_size = 100
learning_rate = 0.001
# MNIST
# torchvision.datasets.MNIST: 这是一个用于加载MNIST数据集的类。 MNIST 数据集,它包含灰度的手写数字图像。每张图像的尺寸是 28x28 像素,灰度图像只有一个通道(channels=1)
# root='./data': root 参数指定了数据集的存储位置 './data' 表示一个相对路径,表示数据集将存储在当前工作目录下的 data 文件夹中。如果这个文件夹不存在,PyTorch 会自动创建它。
# train=True: 表示加载的是训练集数据。
# transform=transforms.ToTensor(): 将图像转换为PyTorch张量,并归一化为[0, 1]的范围。
# download=True: 如果指定的 root 路径下没有找到数据集,会自动从互联网下载MNIST数据集。
train_dataset = torchvision.datasets.MNIST(
root='./data', train=True, transform=transforms.ToTensor(), download=True)
test_dataset = torchvision.datasets.MNIST(
root='./data', train=False, transform=transforms.ToTensor())
# torchvision.datasets.MNIST 是内置的数据集,所以不用去像之前内容中,要搞一个csv文件
# 这里直接把MNIST导入进DataLoader
# batch_size 指定了一次输入模型的数据量。指定batch_size为100,那就是一批次读取100个,利用数据集的索引就可以读取,因为下面还有个参数shuffle=True,所以读取的时候,数据是被打乱的。
train_loader = torch.utils. data.DataLoader(
dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=batch_size, shuffle=False)
print('每份100个,被分成多了份',len(train_loader))
examples = iter(train_loader) # 转换为迭代器,这样可以调用next,一行一行的取数据,只不过他这一行,是一组数据
samples, labels = examples.__next__() # 这里取出 x和y
print(samples.shape, labels.shape) # samples即x,是一个批次,即100个图像
# 这里输出的是torch.Size([100, 1, 28, 28]) torch.Size([100])
# 其中x是的数据维度是下面这样的。
# 第一个维度 (64): 表示批次中包含的样本数量,即 batch_size。在这个例子中,一次输入模型的有 100 张图像。
# 第二个维度 (1): 表示图像的通道数。对于灰度图像,通道数是 1,彩色图像则通常有 3 个通道(对应 RGB)。
# 第三个维度 (28): 表示图像的高度。MNIST 图像的高度为 28 像素。
# 第四个维度 (28): 表示图像的宽度。MNIST 图像的宽度也是 28 像素。
# y只有一个维度,就是100张图像
# x里都数据都是手写的数字,这里可以用图像把他们展示出来看一看
for i in range(6):
plt.subplot(2, 3, i+1) # 在图像窗口中创建一个 2 行 3 列的子图布局,并选择第 i+1 个子图位置。
plt.imshow(samples[i][0], cmap='gray')
# plt.show()
class NeuralNet(nn.Module):
def __init__(self, input_size, hidden_size, num_classes):
super(NeuralNet, self).__init__()
self.linear1 = nn.Linear(input_size, hidden_size)
self.relu = nn.ReLU()
self.linear2 = nn.Linear(hidden_size, num_classes)
def forward(self, x):
out = self.linear1(x)
out = self.relu(out)
out = self.linear2(out)
# no softmax at the end
return out
model = NeuralNet(input_size=input_size,
hidden_size=hidden_size, num_classes=num_classes)
criterion = nn.CrossEntropyLoss() # (applies Softmax) 这里会调用激活函数,所以上面不调用激活函数了
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# training loop
n_total_steps = len(train_loader)
num_epochs = 2
#下面这个循环就走2次,意思是在训练完集合里的全部数据后,在重新来一遍
for epoch in range(num_epochs): #for——range模式=其他语言的for
#下面这个循环是训练集合里的全部数据
for i, (images, labels) in enumerate(train_loader): #for——enumerate模式=其他语言的foreach
# 这里的images是100个图像,也就是一个批次
# 将100,1,28,28 这个四维数组 转换成2维数组,转换结果应该是 100,784
# to(device) 是指将数据转移到这个设备上计算,如果有GPU,这个计算会被加速
images = images.reshape(-1, 28*28).to(device)
labels = labels.to(device)
# forward
outputs = model(images)
loss = criterion(outputs, labels)
# backwards
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i+1) % 100 == 0:
print(
f'epoch {epoch+1} / {num_epochs}, step {i+1}/{n_total_steps}, loss = {loss.item}')
# test
with torch.no_grad():
n_correct = 0
n_samples = 0
for images, labels in test_loader:
images = images.reshape(-1, 28*28).to(device) #转二维数组
labels = labels.to(device)
outputs = model(images) # 通过我们训练的模型,我们得到了y_predicted
# value,index
_, predictions = torch.max(outputs, 1) #torch.max(outputs, 1) 会在 outputs 的每一行(对应每个样本)中找到最大值及其索引。由于模型输出的是每个类别的概率分布,所以最大值的索引代表模型对该图像的预测类别。
n_samples += labels.shape[0] #labels.shape[0]会返回y的行数,就是100,因为一个批次100个图像
print("y行数",labels.shape[0])
#predictions == labels 会生成一个布尔张量(True 表示预测正确,False 表示预测错误)
#sum() 计算正确预测的数量并加到 n_correct 上
n_correct += (predictions == labels).sum().item()
acc = 100.0*n_correct/n_samples #计算正确率
print(f'accuracy ={acc}')
图形
现在我们学会了使用神经网络开发,我们在来看一些图形,就能看懂了。
比如这个M-P神经元模型。

在比如这个神经网络结构图。
下面粉色是输入层,绿色是隐藏层,蓝色是输出层。虽然下面画的隐藏层节点比输入层多,但实际情况并不一定,这只是个示意图,比如我们上面,输入的x是784列,隐藏层计算后,就剩100列。

结语
本质上我并不是python程序员,其实看我的注释就应该能感觉到吧,比如我对python的for循环都会加注释。
我之所以写这个系列,就是因为我不是python开发,这个系列是为了当我间隔超长时间重新使用python时,唤起死去的记忆用的。
不过,我感觉我写的顺序还不错,如果大家反复的仔细的阅读,应该也能掌握神经网络开发。
传送门:
零基础学习人工智能—Python—Pytorch学习(一)
零基础学习人工智能—Python—Pytorch学习(二)
零基础学习人工智能—Python—Pytorch学习(三)
零基础学习人工智能—Python—Pytorch学习(四)
零基础学习人工智能—Python—Pytorch学习(五)
零基础学习人工智能—Python—Pytorch学习(六)
零基础学习人工智能—Python—Pytorch学习(七)
注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!

若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!
https://www.cnblogs.com/kiba/p/18372411
零基础学习人工智能—Python—Pytorch学习(七)的更多相关文章
- 如何零基础开始自学Python编程
转载——原作者:赛门喵 链接:https://www.zhihu.com/question/29138020/answer/141170242 0. 明确目标 我是真正零基础开始学Python的,从一 ...
- 零基础快速掌握Python系统管理视频课程【猎豹网校】
点击了解更多Python课程>>> 零基础快速掌握Python系统管理视频课程[猎豹网校] 课程目录 01.第01章 Python简介.mp4 02.第02章 IPython基础.m ...
- 零基础的人该怎么学习JAVA
对于JAVA有所兴趣但又是零基础的人,该如何学习JAVA呢?对于想要学习开发技术的学子来说找到一个合适自己的培训机构是非常难的事情,在选择的过程中总是 因为这样或那样的问题让你犹豫不决,阻碍你前进的 ...
- 零基础学完Python的7大就业方向,哪个赚钱多?
“ 我想学 Python,但是学完 Python 后都能干啥 ?” “ 现在学 Python,哪个方向最简单?哪个方向最吃香 ?” “ …… ” 相信不少 Python 的初学者,都会遇到上面的这些问 ...
- 零基础怎么学Python编程,新手常犯哪些错误?
Python是人工智能时代最佳的编程语言,入门简单.功能强大,深获初学者的喜爱. 很多零基础学习Python开发的人都会忽视一些小细节,进而导致整个程序出现错误.下面就给大家介绍一下Python开发者 ...
- 零基础如何入门Python
编程零基础如何学习Python 如果你是零基础,注意是零基础,想入门编程的话,我推荐你学Python.虽然国内基本上是以C语言作为入门教学,但在麻省理工等国外大学都是以Python作为编程入门教学的. ...
- 零基础如何学Python爬虫技术?
在作者学习的众多编程技能中,爬虫技能无疑是最让作者着迷的.与自己闭关造轮子不同,爬虫的感觉是与别人博弈,一个在不停的构建 反爬虫 规则,一个在不停的破译规则. 如何入门爬虫?零基础如何学爬虫技术?那前 ...
- 零基础自学人工智能,看这些资料就够了(300G资料免费送)
为什么有今天这篇? 首先,标题不要太相信,哈哈哈. 本公众号之前已经就人工智能学习的路径.学习方法.经典学习视频等做过完整说明.但是鉴于每个人的基础不同,可能需要额外的学习资料进行辅助.特此,向大家免 ...
- 零基础自学用Python 3开发网络爬虫
原文出处: Jecvay Notes (@Jecvay) 由于本学期好多神都选了Cisco网络课, 而我这等弱渣没选, 去蹭了一节发现讲的内容虽然我不懂但是还是无爱. 我想既然都本科就出来工作还是按照 ...
- 零基础如何使用python处理字符串?
摘要:Python的普遍使用场景是自动化测试.爬取网页数据.科学分析之类,这其中都涉及到了对数据的处理,而数据的表现形式很多,今天我们来讲讲字符串的操作. 字符串是作为任意一门编程语言的基础,在P ...
随机推荐
- C++中UNIX时间戳与日期互转
C++中UNIX时间戳与日期互转 使用time.h头文件 localtime 可以把时间戳转为 tm 结构体, tm结构体中可以格式化输出时间 mktime可以把tm结构体转为时间戳 tm 结构体中: ...
- Linux 内核:设备驱动模型(2)driver-bus-device与probe
Linux 内核:设备驱动模型(2)driver-bus-device与probe 系列:Linux 内核:设备驱动模型 学习总结 参考: https://blog.csdn.net/lizuobin ...
- SNAT,DNAT以及REDIRECT转发详解
最近负责的其中一个项目的服务器集群出现了点网络方面的问题,在处理过程当中又涉及到了防火墙相关的知识和命令,想着有一段时间没有复习这部分内容了,于是借着此次机会复写了下顺便将本次复习的一些内容以博客的形 ...
- 【基础整理】Mapping representation 机器人所用地图种类及相关介绍
参考与前言 本文主要介绍 建图 Mapping 方面的一些 基础知识介绍与相关下游任务使用 涉及知识较为基础,SLAM大佬们可以提前退出了 主要针对应用为移动机器人与物流无人驾驶车:提前申明:大部分文 ...
- MySQL 索引失效
全列匹配 最佳左前缀法则 不在索引列上做任何操作(计算.函数.自动.手动类型转换),会导致索引失效 存储引擎不能使用索引中范围条件右边的列 尽量使用覆盖索引(只访问索引的查询(索引和查询列一致)),少 ...
- SpringMVC springmvc.xml配置路径前缀和后缀
web.xml <?xml version="1.0" encoding="UTF-8"?> <web-app xmlns:xsi=" ...
- redis出现错误提示MISCONF Redis is configured to save RDB snapshots, but is currently not able to persist on disk. Commands that may modify the data set are disabled. Please check Redis logs for......
某天,redis出现了这样一个错误提示: MISCONF Redis is configured to save RDB snapshots, but is currently not able to ...
- 洛谷P1378
这道题需要处理的信息比较多,需要注意的是一个油滴扩展后可能会包含其他的点 #include <iostream> #include <utility> #include < ...
- Webpack3.x升级至 4.x 小记
近期项目部署遇到点问题,需要升级webpack版本,特此整理一小记,记录升级过程中的依赖包及报错处理. 本次升级的依赖包及对应版本对照表: npm 包 当前版本 升级版本 S/D vue ^2.5.1 ...
- Linux 提权-密码搜寻
本文通过 Google 翻译 Password Hunting – Linux Privilege Escalation 这篇文章所产生,本人仅是对机器翻译中部分表达别扭的字词进行了校正及个别注释补充 ...