pytorch 入门
import matplotlib.pyplot as plt
from torchvision.transforms import ToTensor
import torch
from torch import nn  # 包含构建神经网络的所有模块
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# 训练数据集
training_data = datasets.FashionMNIST(
    root="./data",          # 存储测试集和训练集的路径
    train=True,           # 训练集
    download=True,        # 如果本机没有数据集,就会下载到 root 目录下。
    transform=ToTensor()  # 对样本数据进行处理,转换为张量数据
)
# 测试数据集
test_data = datasets.FashionMNIST(
    root="./data",
    train=False,
    download=True,
    transform=ToTensor()
)
# 标签字典,一个key键对应一个label
labels_map = {
    0: "T-Shirt",
    1: "Trouser",
    2: "Pullover",
    3: "Dress",
    4: "Coat",
    5: "Sandal",
    6: "Shirt",
    7: "Sneaker",
    8: "Bag",
    9: "Ankle Boot",
}
# 设置画布大小
# figure = plt.figure(figsize=(8, 8))
# cols, rows = 3, 3
# for i in range(1, cols * rows + 1):
#     # 随机生成一个索引
#     sample_idx = torch.randint(len(training_data), size=(1,)).item()
#     # 获取样本及其对应的标签
#     img, label = training_data[sample_idx]
#     figure.add_subplot(rows, cols, i)
#     # 设置标题
#     plt.title(labels_map[label])
#     # 不显示坐标轴
#     plt.axis("off")
#     # 显示灰度图
#     plt.imshow(img.squeeze(), cmap="gray")
# plt.show()
# 训练数据加载器; 根据数据集生成一个迭代对象,用于模型的训练
train_dataloader = DataLoader(
    # 定义好的数据集
    dataset=training_data,
    # 设置批量大小
    batch_size=128,
    # 线程数,默认为0。在Windows下设置大于0的数可能会报错。
    num_workers=0,
    # 打乱样本的顺序
    shuffle=True)
# 测试数据加载器
test_dataloader = DataLoader(
    dataset=test_data,
    batch_size=128,
    shuffle=True)
# 展示图片和标签
# train_features, train_labels = next(iter(train_dataloader))
# print(f"Feature batch shape: {train_features.size()}")
# print(f"Labels batch shape: {train_labels.size()}")
# img = train_features[0].squeeze()
# label = train_labels[0]
# plt.imshow(img, cmap="gray")
# plt.show()
# print(f"Label: {label}")
# 模型定义
class Network(nn.Module):
    def __init__(self):
        super(Network, self).__init__()  # 执行父类中的 init 函数
        self.flatten = nn.Flatten()      # 将每个大小为28x28的图像转换为784个像素值的连续数组
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(in_features=28 * 28, out_features=512), # 线性层
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=512),
            nn.ReLU(),
            nn.Linear(in_features=512, out_features=10),
        )
    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits
# 优化模型参数
def train_loop(dataloader, model, loss_func, optimizer, device):
    size = len(dataloader.dataset)
    for batch, (X, y) in enumerate(dataloader):
        X = X.to(device)
        y = y.to(device)
        # 前向传播,计算预测值
        pred = model(X)
        # 计算损失
        loss = loss_func(pred, y)
        # 反向传播,优化参数
        optimizer.zero_grad()  # 将模型的梯度归 0
        loss.backward()        # 用来存储每个参数的损失梯度
        optimizer.step()       # 梯度调整完以后,调整参数
        if batch % 100 == 0:
            loss, current = loss.item(), batch * len(X)
            print(f"loss: {loss:>7f}  [{current:>5d}/{size:>5d}]")
# 测试模型性能
def test_loop(dataloader, model, loss_fn, device):
    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    with torch.no_grad():
        for X, y in dataloader:
            X = X.to(device)
            y = y.to(device)
            # 前向传播,计算预测值
            pred = model(X)
            # 计算损失
            test_loss += loss_fn(pred, y).item()
            # 计算准确率
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
    test_loss /= num_batches
    correct /= size
    print(
        f"Test Error: \n Accuracy: {(100 * correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
# input
if __name__ == '__main__':
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using {device} device")
    # 定义模型,并将模型移动到设备上
    model = Network().to(device)
    # 设置超参数
    learning_rate = 1e-3
    epochs = 20
    # 定义损失函数和优化器
    loss_func = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(params=model.parameters(), lr=learning_rate)
    # 训练模型
    for t in range(epochs):
        print(f"Epoch {t + 1}\n-------------------------------")
        train_loop(train_dataloader, model, loss_func, optimizer, device)
        test_loop(test_dataloader, model, loss_func, device)
    print("Done!")
    # 保存模型
    torch.save(model.state_dict(), 'model_weights.pth')
pytorch 入门的更多相关文章
- [pytorch] Pytorch入门
		Pytorch入门 简单容易上手,感觉比keras好理解多了,和mxnet很像(似乎mxnet有点借鉴pytorch),记一记. 直接从例子开始学,基础知识咱已经看了很多论文了... import t ... 
- Pytorch入门随手记
		Pytorch入门随手记 什么是Pytorch? Pytorch是Torch到Python上的移植(Torch原本是用Lua语言编写的) 是一个动态的过程,数据和图是一起建立的. tensor.dot ... 
- pytorch 入门指南
		两类深度学习框架的优缺点 动态图(PyTorch) 计算图的进行与代码的运行时同时进行的. 静态图(Tensorflow <2.0) 自建命名体系 自建时序控制 难以介入 使用深度学习框架的优点 ... 
- 超简单!pytorch入门教程(五):训练和测试CNN
		我们按照超简单!pytorch入门教程(四):准备图片数据集准备好了图片数据以后,就来训练一下识别这10类图片的cnn神经网络吧. 按照超简单!pytorch入门教程(三):构造一个小型CNN构建好一 ... 
- pytorch入门2.2构建回归模型初体验(开始训练)
		pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ... 
- pytorch入门2.0构建回归模型初体验(数据生成)
		pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ... 
- pytorch入门2.1构建回归模型初体验(模型构建)
		pytorch入门2.x构建回归模型系列: pytorch入门2.0构建回归模型初体验(数据生成) pytorch入门2.1构建回归模型初体验(模型构建) pytorch入门2.2构建回归模型初体验( ... 
- Pytorch入门——手把手教你MNIST手写数字识别
		MNIST手写数字识别教程 要开始带组内的小朋友了,特意出一个Pytorch教程来指导一下 [!] 这里是实战教程,默认读者已经学会了部分深度学习原理,若有不懂的地方可以先停下来查查资料 目录 MNI ... 
- Pytorch入门上 —— Dataset、Tensorboard、Transforms、Dataloader
		本节内容参照小土堆的pytorch入门视频教程.学习时建议多读源码,通过源码中的注释可以快速弄清楚类或函数的作用以及输入输出类型. Dataset 借用Dataset可以快速访问深度学习需要的数据,例 ... 
- Pytorch入门中 —— 搭建网络模型
		本节内容参照小土堆的pytorch入门视频教程,主要通过查询文档的方式讲解如何搭建卷积神经网络.学习时要学会查询文档,这样会比直接搜索良莠不齐的博客更快.更可靠.讲解的内容主要是pytorch核心包中 ... 
随机推荐
- .net 多地点计算中心点
			1.需求产生 快到周末了,几个远在各个区的朋友想要聚餐,为了照顾到彼此的距离,决定计算一下所有人的中心点,至此需求产生,下面开始编写代码. 2.编写代码 1)新建一个控制台程序 在NuGet程序包管理 ... 
- ubuntu18.04下联想电脑不能打开wifi
			一.问题描述: 本人使用联想拯救者14IFI笔记本在安装Ubuntu系统时会出现无线硬件开关关闭的问题,当然也就无法连网(Wi-Fi). (最好先使用 sudo rfkill unblock all) ... 
- 美团点评CAT部署了各种环境不下10次,遇到的坑整理
			CAT是什么 我的理解是一个收集服务调用等运行情况的监控系统. 相信你能搜到这篇博客我就不多介绍了,这里有链接 传送门 本博客仅仅只帮助大家解决部署方面的问题 来自一个用户的吐槽 1.部署真他娘的困难 ... 
- Prometheus插件安装(cadvisor)
			简介 当docker服务数量到一定程度,为了保证系统的文档,我们就需要对docker进行监控.一般情况下我们可以通过docker status命令来做简单的监控,但是无法交给prometheus采集, ... 
- NOIP2017 - D2T3 - phalanx
			按照思维难度加大和代码难度减小的顺序,我们来看这道题的不同做法. 若你无畏,我亦无畏 - 平衡树 平衡树简直是天然用来维护这种操作的--合并两个区间,提取一个值.我们可以对每个行的前 \(m-1\) ... 
- 基于docker的spark分布式与单线程、多线程wordcount的对比实验
			1. 分布式环境搭建 1.1 基于docker的spark配置文件 docker-compose.yml version: '2' services: spark: image: docker.io/ ... 
- EF core 反向工程 连接字符串
			Scaffold-DbContext 'Data Source=.;Initial Catalog=DB;User ID=sa;Password=1;Integrated Security=true; ... 
- Column count doesn't match value count at row 1存储的数据与数据库表的字段类型定义不相匹配
			一.造成这个原因可能是一个关于创建json数据类型的mysql表格插入的一个报错提示: 26行为错误示范:27是正确书写规范. 
- 解决:webpack打包js项目ie11浏览器下报promise 未定义
			项目背景:webpack+js+seajs 引入文件用require或者define 1.下载依赖包 npm install babel-polyfill 2.引入该依赖:webpack.conf ... 
- function 和mapped function的区别
			1 --在函数定义上使用mapped前缀将此函数标记为自动映射到集合上.这意味着,如果将集合作为函数的第一个参数,则该函数将在集合的元素上自动重复调用.这允许您定义脚本化函数,这些函数的行为方式与映射 ... 
