• 本次将一个使用Pytorch的一个实战项目,记录流程:自定义数据集->数据加载->搭建神经网络->迁移学习->保存模型->加载模型->测试模型

  • 自定义数据集

    参考我的上一篇博客:自定义数据集处理

  • 数据加载

    默认小伙伴有对深度学习框架有一定的了解,这里就不做过多的说明了。

    好吧,还是简单的说一下吧:

    我们在做好了自定义数据集之后,其实数据的加载和MNSIT 、CIFAR-10 、CIFAR-100等数据集的都是相似的,过程如下所示:

    • 导入必要的包
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
  • 加载数据

    可以发现和MNIST 、CIFAR的加载基本上是一样的
train_db = Pokemon('pokeman', 224, mode='train')
val_db = Pokemon('pokeman', 224, mode='val')
test_db = Pokemon('pokeman', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
  • 搭建神经网络

    ResNet-18网络结构:



    ResNet全名Residual Network残差网络。Kaiming He 的《Deep Residual Learning for Image Recognition》获得了CVPR最佳论文。他提出的深度残差网络在2015年可以说是洗刷了图像方面的各大比赛,以绝对优势取得了多个比赛的冠军。而且它在保证网络精度的前提下,将网络的深度达到了152层,后来又进一步加到1000的深度。论文的开篇先是说明了深度网络的好处:特征等级随着网络的加深而变高,网络的表达能力也会大大提高。因此论文中提出了一个问题:是否可以通过叠加网络层数来获得一个更好的网络呢?作者经过实验发现,单纯的把网络叠起来的深层网络的效果反而不如合适层数的较浅的网络效果。因此何恺明等人在普通平原网络的基础上增加了一个shortcut, 构成一个residual block。此时拟合目标就变为F(x),F(x)就是残差:

    • 训练模型
def evalute(model, loader):
model.eval()
correct = 0
total = len(loader.dataset)
for x, y in loader:
x, y = x.to(device), y.to(device)
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
return correct / total
def main():
model = ResNet18(5).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs):
for step, (x, y) in enumerate(train_loader):
x, y = x.to(device), y.to(device)
model.train()
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
if epoch % 1 == 0:
val_acc = evalute(model, val_loader)
if val_acc > best_acc:
best_epoch = epoch
best_acc = val_acc
viz.line([val_acc], [global_step], win='val_acc', update='append')
print('best acc:', best_acc, 'best epoch:', best_epoch)
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')
test_acc = evalute(model, test_loader)
  • 迁移学习

    提升模型的准确率:

# model = ResNet18(5).to(device)
trained_model=resnet18(pretrained=True) # 此时是一个非常好的model
model = nn.Sequential(*list(trained_model.children())[:-1], # 此时使用的是前17层的网络 0-17 *:随机打散
Flatten(),
nn.Linear(512,5)
).to(device)
# x=torch.randn(2,3,224,224)
# print(model(x).shape)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
  • 保存、加载模型

    pytorch保存模型的方式有两种:

    第一种:将整个网络都都保存下来

    第二种:仅保存和加载模型参数(推荐使用这样的方法)
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

可以看到这是我保存的模型:

其中best.mdl是第二中方法保存的

model.pkl则是第一种方法保存的

  • 测试模型

    这里是训练时的情况



    看这个数据准确率还是不错的,但是还是需要实际的测试这个模型,看它到底学到东西了没有,接下来简单的测试一下:

import torch
from PIL import Image
from torchvision import transforms
device = torch.device('cuda')
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
def prediect(img_path):
net=torch.load('model.pkl')
net=net.to(device)
torch.no_grad()
img=Image.open(img_path)
img=transform(img).unsqueeze(0)
img_ = img.to(device)
outputs = net(img_)
_, predicted = torch.max(outputs, 1)
# print(predicted)
print('this picture maybe :',classes[predicted[0]])
if __name__ == '__main__':
prediect('./test/name.jpg')

实际的测试结果





效果还是可以的,完整的代码:

https://github.com/huzixuan1/Loader_DateSet

数据集下载链接:

https://pan.baidu.com/s/12-NQiF4fXEOKrXXdbdDPCg

由于笔者能力水平有限,在表述上可能有些不准确;有问题可以联系QQ:1017190168

PyTorch 实战(模型训练、模型加载、模型测试)的更多相关文章

  1. Pytorch实战学习(四):加载数据集

    <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili Dataset & Dataloader 1.Dataset & Dataloader作用 ※Datas ...

  2. 深度学习-05(tensorflow模型保存与加载、文件读取、图像分类:手写体识别、服饰识别)

    文章目录 深度学习-05 模型保存于加载 什么是模型保存与加载 模型保存于加载API 案例1:模型保存/加载 读取数据 文件读取机制 文件读取API 案例2:CSV文件读取 图片文件读取API 案例3 ...

  3. [译]Vulkan教程(31)加载模型

    [译]Vulkan教程(31)加载模型 Loading models 加载模型 Introduction 入门 Your program is now ready to render textured ...

  4. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  5. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

  6. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  7. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  8. [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...

  9. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  10. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

随机推荐

  1. Elasticsearch之常用术语

    一. 数据库和ES简单类比 关系型数据库 表(Table) 行(Row) 列(Cloumn) Schema SQL Elasticsearch 索引(Index) 文档(Document) 字段(Fi ...

  2. HiAI Foundation助力端侧音视频AI能力,高性能低功耗释放云侧成本

    过去三年是端侧AI高速发展的几年,华为在2020年预言了端侧AI的发展潮流,2021年通过提供端云协同的方式使我们的HiAI Foundation应用性更进一个台阶,2022年提供视频超分端到端的解决 ...

  3. 初露头角!Walrus入选服贸会“数智影响力”数字化转型创新案例

    9月5日,由北京市通信管理局.工业和信息化部新闻宣传中心联合主办的"企业数字化转型论坛"在2023中国国际服务贸易交易会期间召开,论坛以"数字化引领 高质量发展" ...

  4. 利用别名简化进入docker容器数据库的操作

    之前研究docker和数据库的交互,越发对docker这个东西喜爱了.因为平常偶尔会用到各类数据库测试环境验证一些想法,需要进一步简化进入到这些环境的步骤. 比如我现在有三套docker容器数据库测试 ...

  5. Vue3+vite路由配置优化(自动化导入)

    今天在维护优化公司中台项目时,发现路由的文件配置非常多非常乱,只要只中大型项目,都会进入很多的路由页面,规范一点的公司还会吧路由进行模块化导入,但是依然存在很多文件夹的和手动导入的问题. 于是我想到了 ...

  6. Record - Nov. 28st, 2020 - Exam. REC

    Prob. 1 Desc. & Link. 暴力为 \(\Theta(NK)\). 正解(也许): 把每一个全为正整数的子段找出来. 然后判断一下中间连接的情况即可. 但是这样决策情况太多了. ...

  7. 如何在没有第三方.NET库源码的情况,调试第三库代码?

    大家好,我是沙漠尽头的狼. 本方首发于Dotnet9,介绍使用dnSpy调试第三方.NET库源码,行文目录: 安装dnSpy 编写示例程序 调试示例程序 调试.NET库原生方法 总结 1. 安装dnS ...

  8. 关于wake on lan远程唤醒主机的问题,长时间关机无法远程唤醒

    英特尔在年初发布了几款低功耗的CPU,国内厂商在迷你主机领域纷纷搭载新款CPU,卖的火爆.之前关注过迷你主机这块,于是,我也入手一个迷你主机玩玩,买的是板载N100的迷你主机.使用过程中会涉及到如何远 ...

  9. 【matplotlib 实战】--面积图

    面积图,或称区域图,是一种随有序变量的变化,反映数值变化的统计图表. 面积图也可用于多个系列数据的比较.这时,面积图的外观看上去类似层叠的山脉,在错落有致的外形下表达数据的总量和趋势.面积图不仅可以清 ...

  10. 洛谷题解 | P1046 陶陶摘苹果

    ​ 目录 题目描述 输入格式 输出格式 输入输出样例 说明/提示 题目思路 AC代码 题目描述 陶陶家的院子里有一棵苹果树,每到秋天树上就会结出 10 个苹果.苹果成熟的时候,陶陶就会跑去摘苹果.陶陶 ...