pytorch 入门
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
import torch
from torch import nn # 包含构建神经网络的所有模块
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 训练数据集
training_data = datasets.FashionMNIST(
root="./data", # 存储测试集和训练集的路径
train=True, # 训练集
download=True, # 如果本机没有数据集,就会下载到 root 目录下。
transform=ToTensor() # 对样本数据进行处理,转换为张量数据
)
# 测试数据集
test_data = datasets.FashionMNIST(
root="./data",
train=False,
download=True,
transform=ToTensor()
)
# 标签字典,一个key键对应一个label
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
# 设置画布大小
# figure = plt.figure(figsize=(8, 8))
# cols, rows = 3, 3
# for i in range(1, cols * rows + 1):
# # 随机生成一个索引
# sample_idx = torch.randint(len(training_data), size=(1,)).item()
# # 获取样本及其对应的标签
# img, label = training_data[sample_idx]
# figure.add_subplot(rows, cols, i)
# # 设置标题
# plt.title(labels_map[label])
# # 不显示坐标轴
# plt.axis("off")
# # 显示灰度图
# plt.imshow(img.squeeze(), cmap="gray")
# plt.show()
# 训练数据加载器; 根据数据集生成一个迭代对象,用于模型的训练
train_dataloader = DataLoader(
# 定义好的数据集
dataset=training_data,
# 设置批量大小
batch_size=128,
# 线程数,默认为0。在Windows下设置大于0的数可能会报错。
num_workers=0,
# 打乱样本的顺序
shuffle=True)
# 测试数据加载器
test_dataloader = DataLoader(
dataset=test_data,
batch_size=128,
shuffle=True)
# 展示图片和标签
# train_features, train_labels = next(iter(train_dataloader))
# print(f"Feature batch shape: {train_features.size()}")
# print(f"Labels batch shape: {train_labels.size()}")
# img = train_features[0].squeeze()
# label = train_labels[0]
# plt.imshow(img, cmap="gray")
# plt.show()
# print(f"Label: {label}")
# 模型定义
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__() # 执行父类中的 init 函数
self.flatten = nn.Flatten() # 将每个大小为28x28的图像转换为784个像素值的连续数组
self.linear_relu_stack = nn.Sequential(
nn.Linear(in_features=28 * 28, out_features=512), # 线性层
nn.ReLU(),
nn.Linear(in_features=512, out_features=512),
nn.ReLU(),
nn.Linear(in_features=512, out_features=10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# 优化模型参数
def train_loop(dataloader, model, loss_func, optimizer, device):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
X = X.to(device)
y = y.to(device)
# 前向传播,计算预测值
pred = model(X)
# 计算损失
loss = loss_func(pred, y)
# 反向传播,优化参数
optimizer.zero_grad() # 将模型的梯度归 0
loss.backward() # 用来存储每个参数的损失梯度
optimizer.step() # 梯度调整完以后,调整参数
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
# 测试模型性能
def test_loop(dataloader, model, loss_fn, device):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X = X.to(device)
y = y.to(device)
# 前向传播,计算预测值
pred = model(X)
# 计算损失
test_loss += loss_fn(pred, y).item()
# 计算准确率
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
# input
if __name__ == '__main__':
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# 定义模型,并将模型移动到设备上
model = Network().to(device)
# 设置超参数
learning_rate = 1e-3
epochs = 20
# 定义损失函数和优化器
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# 训练模型
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train_loop(train_dataloader, model, loss_func, optimizer, device)
test_loop(test_dataloader, model, loss_func, device)
print("Done!")
# 保存模型
torch.save(model.state_dict(), 'model_weights.pth')
pytorch 入门的更多相关文章
- [pytorch] Pytorch入门
Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...
- Pytorch入门随手记
Pytorch入门随手记 什么是Pytorch? Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的) 是一个动态的过程,数据和图是一起建立的. tensor.dot ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
- Pytorch入门中 —— 搭建网络模型
本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...
随机推荐
- Ribbon服务调用+负载均衡(入门)
1.Ribbon Ribbon中文:(用于捆绑或装饰的)带子; 丝带; 带状物; 主要功能是提供客户端的软件负载均衡算法和服务调用 Ribbon已经进入了维护模式了,但是Ribbon仍然被广泛使用中 ...
- 大三末java实习生一面凉经
在南京投了一些小公司,想找个实习,因为知道自己很菜,就收到了一家公司的面试. 面试一般在线上进行,我是在腾讯会议上进行的.面试官其实挺好的,一般不会为难你,因为他知道你是在校生不会懂那么多企业的技术. ...
- What?JMeter做UI自动化!
JMeter做UI自动化 不推荐,好别扭,但可行 插件安装 搜插件selenium,安装 添加config 添加线程组 右键线程组->添加->配置元件->jp@gc - Chrome ...
- yarn的基础语法:yarn安装完vue cli3后提示不是内部命令
: 第一步:安装nodejs: 第二步:全局安装vue-cli 解决方案: 全局搜索vue.cmd 将vue.cmd所在的路径添加到环境变量Path后面.再执行vue -V即可.
- LeetCode 周赛 334,在算法的世界里反复横跳
本文已收录到 AndroidFamily,技术和职场问题,请关注公众号 [彭旭锐] 提问. 大家好,我是小彭. 今天是 LeetCode 第 334 场周赛,你参加了吗?这场周赛考察范围比较基础,整体 ...
- vulnhub靶场之MATRIX-BREAKOUT: 2 MORPHEUS
准备: 攻击机:虚拟机kali.本机win10. 靶机:Matrix-Breakout: 2 Morpheus,下载地址:https://download.vulnhub.com/matrix-bre ...
- pat乙级1023 组个最小数
#include <stdio.h> #include <stdlib.h> #include <string.h> #include <math.h> ...
- 如何设置QGraphicsItem线宽不随QGraphicsView缩放而变小或变大
很简单,只需要重写一下Item中的paint()方法 void my_line_item::paint(QPainter *painter, const QStyleOptionGraphicsIte ...
- 记一次oracle单表改分区表 一波三折
业务上要把单表还差分区表 ```SQL> @seg gwx.aopen SEG_MB OWNER SEGMENT_NAME SEG_PART_NAME SEGMENT_TYPE SEG_TABL ...
- 【Nginx】优化,增加线程
https://blog.csdn.net/cnskylee/article/details/127645806 众所周知,Nginx一款体积小巧,但是性能强大的软负载,主要被用作后端服务和应用的反向 ...