前言

本文介绍使用神经网络进行实战。

使用的代码是《零基础学习人工智能—Python—Pytorch学习(九)》里的代码。

代码实现

mudule定义

首先我们自定义一个module,创建一个torch_test17_Model.py文件(这个module要单独用个py文件定义),如下:

import torch.nn as nn
import torch.nn.functional as F class ConvNet(nn.Module):
def __init__(self):
super(ConvNet, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10) def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1, 16*5*5)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
return x

module创建

编写创建module的py文件,代码如下:

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch_test17_Model as tm device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') input_size = 784
hidden_size = 100
num_classes = 10
batch_size = 100
learning_rate = 0.001
num_epochs = 200 # 要训练200-400轮效果最好 transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) train_dataset = torchvision.datasets.CIFAR10(
root='./data', train=True, download=True, transform=transform) train_loader = torch.utils. data.DataLoader(
dataset=train_dataset, batch_size=batch_size, shuffle=True) model = tm.ConvNet().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate) n_total_steps = len(train_loader)
print("number total epochs(训练的回合):",num_epochs)
print("number total steps(训练的次数):",n_total_steps) for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# images.shape: torch.Size([100, 3, 32, 32])
# images张量的四个维度是(B, C, H, W)
# B 是批量大小(即图像的数量)。
# C 是图像的通道数(例如,RGB 图像的通道数是 3)。
# H 和 W 分别是图像的高度和宽度。
print("images.shape:", images.shape) #100行,后面的维度是3,32,32。这个是图片信息。
# lables是对应images这100个图片的标签
print("labels.shape:", labels.shape)
print("labels[0].item():", labels[0].item()) # 输出例子 labels[0].item()=6
images = images.to(device)
labels = labels.to(device)
# 前向传播
outputs = model(images)
loss = criterion(outputs, labels)
print("loss.item()",loss.item()) # 输出例子 loss.item()=2.300053596496582
# 逆向传播和优化
optimizer.zero_grad()
loss.backward() ##执行逆向传播 会使用criterion的函数关系求偏导,然后把x的值,带入偏导公式求值,然后再乘以loss,得到新x值
optimizer.step()
print(f'训练轮次Epoch [{epoch}/{num_epochs}], Step [{i+1}/{n_total_steps}], Loss: {loss.item():.4f}')
print('==================')
print('训练结束') filePath = "model.pth" #没有路径,会保存到python文件所在目录
torch.save(model, filePath)
print('保存完成')

代码会输出loss的值,我们要重点关注这个值。

Loss 值越大,表示模型的预测与真实标签之间的差距较大,模型的性能较差。

Loss 值越小,表示模型的预测更接近真实标签,性能逐渐提高。

即,loss值接近0的时候,这个模型就可以用了。

module使用

编写使用module验证图片的py文件,注意要引用torch_test17_Model.py文件,代码如下:

import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch_test17_Model as tm device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') batch_size = 100 transform = transforms.Compose(
[transforms.Resize((32, 32)),# 如果预测时处理的图片尺寸与训练时不同,如评估输入的图片尺寸为 [100, 3, 64, 64],而模型训练使用的尺寸是 [100, 3, 32, 32],可以用這個转换一下
transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) test_dataset = torchvision.datasets.CIFAR10(
root='./data', train=False, download=True, transform=transform) test_loader = torch.utils.data.DataLoader(
dataset=test_dataset, batch_size=batch_size, shuffle=False) filePath = "model.pth" #没有路径,会保存到python文件所在目录
model = torch.load(filePath,weights_only=False)
model.eval() # 切换到评估模式 ############################使用阈值判断######################################
threshold = 0.7 # 设定一个阈值,表示模型的信心度,用阈值判断的话,要求模型必须更精确,如果只是两轮的训练,会出现全部判定不过去的情况
with torch.no_grad():
for images, labels in test_loader:
print("############################判断######################################")
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
print("outputs.shape",outputs.shape)
# 计算 softmax 概率
probabilities = F.softmax(outputs, dim=1) max_probs, predicted = torch.max(probabilities, 1)
for i in range(len(predicted)):
if max_probs[i] < threshold: # 如果置信度低于阈值,认为是未知类别
print(f"图片 {i} 被认为是未知类别,置信度 {max_probs[i]:.4f}")
else:
print(f"图片 {i} 被认为是类别 {predicted[i]},置信度 {max_probs[i]:.4f}")

判断图片是什么的时候,使用阈值模式。

结语

到此,我们对于神经网络,卷积神经网络,深度网络都有了一定了解。

然后我们就可以继续学习transformer了。


传送门:

零基础学习人工智能—Python—Pytorch学习—全集


注:此文章为原创,任何形式的转载都请联系作者获得授权并注明出处!



若您觉得这篇文章还不错,请点击下方的【推荐】,非常感谢!

https://www.cnblogs.com/kiba/p/18609581

零基础学习人工智能—Python—Pytorch学习(十二)的更多相关文章

  1. 如何零基础开始自学Python编程

    转载——原作者:赛门喵 链接:https://www.zhihu.com/question/29138020/answer/141170242 0. 明确目标 我是真正零基础开始学Python的,从一 ...

  2. 零基础快速掌握Python系统管理视频课程【猎豹网校】

    点击了解更多Python课程>>> 零基础快速掌握Python系统管理视频课程[猎豹网校] 课程目录 01.第01章 Python简介.mp4 02.第02章 IPython基础.m ...

  3. 进击的Python【第十二章】:mysql介绍与简单操作,sqlachemy介绍与简单应用

    进击的Python[第十二章]:mysql介绍与简单操作,sqlachemy介绍与简单应用 一.数据库介绍 什么是数据库? 数据库(Database)是按照数据结构来组织.存储和管理数据的仓库,每个数 ...

  4. 零基础的人该怎么学习JAVA

    对于JAVA有所兴趣但又是零基础的人,该如何学习JAVA呢?对于想要学习开发技术的学子来说找到一个合适自己的培训机构是非常难的事情,在选择的过程中总是  因为这样或那样的问题让你犹豫不决,阻碍你前进的 ...

  5. 零基础如何学Python爬虫技术?

    在作者学习的众多编程技能中,爬虫技能无疑是最让作者着迷的.与自己闭关造轮子不同,爬虫的感觉是与别人博弈,一个在不停的构建 反爬虫 规则,一个在不停的破译规则. 如何入门爬虫?零基础如何学爬虫技术?那前 ...

  6. 零基础学完Python的7大就业方向,哪个赚钱多?

    “ 我想学 Python,但是学完 Python 后都能干啥 ?” “ 现在学 Python,哪个方向最简单?哪个方向最吃香 ?” “ …… ” 相信不少 Python 的初学者,都会遇到上面的这些问 ...

  7. 零基础怎么学Python编程,新手常犯哪些错误?

    Python是人工智能时代最佳的编程语言,入门简单.功能强大,深获初学者的喜爱. 很多零基础学习Python开发的人都会忽视一些小细节,进而导致整个程序出现错误.下面就给大家介绍一下Python开发者 ...

  8. 零基础如何入门Python

    编程零基础如何学习Python 如果你是零基础,注意是零基础,想入门编程的话,我推荐你学Python.虽然国内基本上是以C语言作为入门教学,但在麻省理工等国外大学都是以Python作为编程入门教学的. ...

  9. 零基础自学人工智能,看这些资料就够了(300G资料免费送)

    为什么有今天这篇? 首先,标题不要太相信,哈哈哈. 本公众号之前已经就人工智能学习的路径.学习方法.经典学习视频等做过完整说明.但是鉴于每个人的基础不同,可能需要额外的学习资料进行辅助.特此,向大家免 ...

  10. 零基础自学用Python 3开发网络爬虫

    原文出处: Jecvay Notes (@Jecvay) 由于本学期好多神都选了Cisco网络课, 而我这等弱渣没选, 去蹭了一节发现讲的内容虽然我不懂但是还是无爱. 我想既然都本科就出来工作还是按照 ...

随机推荐

  1. 11-02 NOIP练习赛

    11-02 NOIP练习赛 为什么休息的天还要打练习赛,这不公平!!!!!!!!!! oh no! 但是三道题确实挺简单,也少见的很有意思. [USACO23OPEN] Milk Sum S 题面翻译 ...

  2. SXYZ-7.3训练赛

    T1 房 啥啥啥,T1又又又爆了,整个人精神状态 良好. 解题思路 考虑数据保证任意两个房子不重合 建一个结构体存两边 最后判断一下 \(>t\) 加两个 \(==t\) 加一个 == 但是!! ...

  3. Oracle ADG 自动切换脚本分享

    为大家分享一个[Oracle ADG自动切换]的脚本,由云和恩墨工程师HongyeDBA编写,支持Switchover.Failover. 下载链接:https://www.modb.pro/down ...

  4. C# 中的四种整形数据

    // C# 中有四种整数类型 byte short int long byte bMax = byte.MaxValue; /// 255 最大值 byte bMin = byte.MinValue; ...

  5. ADO.NET 连接数据库 【vs2022 + sqlServer】

    using System.Data; using System.Data.SqlClient; namespace Zhu.ADO.NET { internal class Program { pri ...

  6. ribbon配置负载均衡策略

    ribbon的负载均衡策略 com.netflix.loadbalancer.RandomRule:从提供服务的实例中以随机的方式: com.netflix.loadbalancer.RoundRob ...

  7. Linux命令netstat查看端口使用方法

    [redis@fgedu180 ~]$ netstat -an|grep 6379 tcp 0 0 192.168.4.180:6379 0.0.0.0:* LISTEN

  8. java程序设置开机自启

    Linux系统jar包开机自启 第一步:创建service文件 sudo nano etc/systemd/system/myapp.service 第二步:将下面代码复制到刚才创建的文件里面,保存 ...

  9. Cartographer学习——地图概率更新过程

    前言:最近一直在研究建图,对google的开源SLAM框架 Cartographer 进行了源码梳理,发现很多巧妙的算法设计,结合原论文 <Real-time Loop Closure in 2 ...

  10. 工作中的技术总结_JQuery_20210825

    工作中的技术总结_JQuery_20210825 JQuery此前接触不多,所以先把此次接触的一些基本操作 1.DOM节点的取值或者赋值: 语法: $(selector).val(value) 参数 ...