pytorch 入门
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
import torch
from torch import nn # 包含构建神经网络的所有模块
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 训练数据集
training_data = datasets.FashionMNIST(
root="./data", # 存储测试集和训练集的路径
train=True, # 训练集
download=True, # 如果本机没有数据集,就会下载到 root 目录下。
transform=ToTensor() # 对样本数据进行处理,转换为张量数据
)
# 测试数据集
test_data = datasets.FashionMNIST(
root="./data",
train=False,
download=True,
transform=ToTensor()
)
# 标签字典,一个key键对应一个label
labels_map = {
0: "T-Shirt",
1: "Trouser",
2: "Pullover",
3: "Dress",
4: "Coat",
5: "Sandal",
6: "Shirt",
7: "Sneaker",
8: "Bag",
9: "Ankle Boot",
}
# 设置画布大小
# figure = plt.figure(figsize=(8, 8))
# cols, rows = 3, 3
# for i in range(1, cols * rows + 1):
# # 随机生成一个索引
# sample_idx = torch.randint(len(training_data), size=(1,)).item()
# # 获取样本及其对应的标签
# img, label = training_data[sample_idx]
# figure.add_subplot(rows, cols, i)
# # 设置标题
# plt.title(labels_map[label])
# # 不显示坐标轴
# plt.axis("off")
# # 显示灰度图
# plt.imshow(img.squeeze(), cmap="gray")
# plt.show()
# 训练数据加载器; 根据数据集生成一个迭代对象,用于模型的训练
train_dataloader = DataLoader(
# 定义好的数据集
dataset=training_data,
# 设置批量大小
batch_size=128,
# 线程数,默认为0。在Windows下设置大于0的数可能会报错。
num_workers=0,
# 打乱样本的顺序
shuffle=True)
# 测试数据加载器
test_dataloader = DataLoader(
dataset=test_data,
batch_size=128,
shuffle=True)
# 展示图片和标签
# train_features, train_labels = next(iter(train_dataloader))
# print(f"Feature batch shape: {train_features.size()}")
# print(f"Labels batch shape: {train_labels.size()}")
# img = train_features[0].squeeze()
# label = train_labels[0]
# plt.imshow(img, cmap="gray")
# plt.show()
# print(f"Label: {label}")
# 模型定义
class Network(nn.Module):
def __init__(self):
super(Network, self).__init__() # 执行父类中的 init 函数
self.flatten = nn.Flatten() # 将每个大小为28x28的图像转换为784个像素值的连续数组
self.linear_relu_stack = nn.Sequential(
nn.Linear(in_features=28 * 28, out_features=512), # 线性层
nn.ReLU(),
nn.Linear(in_features=512, out_features=512),
nn.ReLU(),
nn.Linear(in_features=512, out_features=10),
)
def forward(self, x):
x = self.flatten(x)
logits = self.linear_relu_stack(x)
return logits
# 优化模型参数
def train_loop(dataloader, model, loss_func, optimizer, device):
size = len(dataloader.dataset)
for batch, (X, y) in enumerate(dataloader):
X = X.to(device)
y = y.to(device)
# 前向传播,计算预测值
pred = model(X)
# 计算损失
loss = loss_func(pred, y)
# 反向传播,优化参数
optimizer.zero_grad() # 将模型的梯度归 0
loss.backward() # 用来存储每个参数的损失梯度
optimizer.step() # 梯度调整完以后,调整参数
if batch % 100 == 0:
loss, current = loss.item(), batch * len(X)
print(f"loss: {loss:>7f} [{current:>5d}/{size:>5d}]")
# 测试模型性能
def test_loop(dataloader, model, loss_fn, device):
size = len(dataloader.dataset)
num_batches = len(dataloader)
test_loss, correct = 0, 0
with torch.no_grad():
for X, y in dataloader:
X = X.to(device)
y = y.to(device)
# 前向传播,计算预测值
pred = model(X)
# 计算损失
test_loss += loss_fn(pred, y).item()
# 计算准确率
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(
f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
# input
if __name__ == '__main__':
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
# 定义模型,并将模型移动到设备上
model = Network().to(device)
# 设置超参数
learning_rate = 1e-3
epochs = 20
# 定义损失函数和优化器
loss_func = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
# 训练模型
for t in range(epochs):
print(f"Epoch {t + 1}\n-------------------------------")
train_loop(train_dataloader, model, loss_func, optimizer, device)
test_loop(test_dataloader, model, loss_func, device)
print("Done!")
# 保存模型
torch.save(model.state_dict(), 'model_weights.pth')
pytorch 入门的更多相关文章
- [pytorch] Pytorch入门
Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ...
- Pytorch入门随手记
Pytorch入门随手记 什么是Pytorch? Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的) 是一个动态的过程,数据和图是一起建立的. tensor.dot ...
- pytorch 入门指南
两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ...
- 超简单!pytorch入门教程(五):训练和测试CNN
我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ...
- pytorch入门2.2构建回归模型初体验(开始训练)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.0构建回归模型初体验(数据生成)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- pytorch入门2.1构建回归模型初体验(模型构建)
pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ...
- Pytorch入门——手把手教你MNIST手写数字识别
MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ...
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ...
- Pytorch入门中 —— 搭建网络模型
本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ...
随机推荐
- Spring Boot自动配置原理懂后轻松写一个自己的starter
目前很多Spring项目的开发都会直接用到Spring Boot.因为Spring原生开发需要加太多的配置,而使用Spring Boot开发很容易上手,只需遵循Spring Boot开发的约定就行了, ...
- 利用canvas+js完成滑块验证码中canvas部分思路
1. 最终效果 2.滑块验证码思路 大概思路:设置两个画布,一个为显示图像的canvas画布,一个为拼图的block画布,block画布拼图内容从图像画布中的一部分裁剪得到(使用clip()),通过绑 ...
- 【KAWAKO】speechmetrics-语音方面评价指标库的安装与使用
目录 简介 安装 将工程以压缩包形式下载到本地 把压缩包传到服务器(你想部署的地方)上进行解压 用编辑器打开setup.py进行修改 在工程目录下进行安装 测试 简介 speechmetrics库提供 ...
- jenkins简单安装及配置(Windows环境
jenkins简单安装及配置(Windows环境) jenkins是一款跨平台的持续集成和持续交付.基于Java开发的开源软件,提供任务构建.持续集成监控的功能,可以使开发测试人员更方便的构建软件项目 ...
- JVM相关知识学习
JVM的垃圾回收算法是什么? 分代回收算法:然后详细阐述年轻代有哪些算法,老年代有哪些算法 垃圾收集器总结: 最初使用的是Serial + Serial Old收集垃圾,最简单,因为二者都是单线程的, ...
- 《爆肝整理》保姆级系列教程-玩转Charles抓包神器教程(8)-Charles如何进行断点调试
1.简介 Charles和Fiddler一样也有个强大的功能,可以修改发送到服务器的数据包,但是修改前需要拦截,即设置断点.设置断点后,开始拦截接下来所有网页,直到取消断点.这个功能可以在数据包发送之 ...
- vue路由中pdfjs插件使用及找不到 viewer.html解决
官方下载: https://mozilla.github.io/pdf.js/getting_started/#download 同目录下pdfjs-2.12.313-dist.zip为官方下载包 此 ...
- 探索 C 语言的指针
指针的概念 指针代表一个变量的内存地址,通过&可以拿到变量的内存地址.变量不等于指针,通过*可以拿到指针所指向的变量的值. 在 C 中,存在指针变量,指针变量的声明格式:int* varNam ...
- qt_2d画图
qt中如何画图? 使用Qpainter类进行2D绘画 例如以下代码进行直线的绘制 QPainter painter(this);painter.setPen(Qt::red);painter.draw ...
- 微信小程序扫码
前言:微信小程序-->调用摄像头,扫描二维码/条形码,并获取信息,一连串操作,只需要调用微信小程序提供的 wx.scanCode API. 一.生成测试二维码 随便网上找个二维码生成器. 二.实 ...