• 本次将一个使用Pytorch的一个实战项目,记录流程:自定义数据集->数据加载->搭建神经网络->迁移学习->保存模型->加载模型->测试模型

  • 自定义数据集

    参考我的上一篇博客:自定义数据集处理

  • 数据加载

    默认小伙伴有对深度学习框架有一定的了解,这里就不做过多的说明了。

    好吧,还是简单的说一下吧:

    我们在做好了自定义数据集之后,其实数据的加载和MNSIT 、CIFAR-10 、CIFAR-100等数据集的都是相似的,过程如下所示:

    • 导入必要的包
import torch
from torch import optim, nn
import visdom
from torch.utils.data import DataLoader
  • 加载数据

    可以发现和MNIST 、CIFAR的加载基本上是一样的
train_db = Pokemon('pokeman', 224, mode='train')
val_db = Pokemon('pokeman', 224, mode='val')
test_db = Pokemon('pokeman', 224, mode='test')
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,
num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)
  • 搭建神经网络

    ResNet-18网络结构:



    ResNet全名Residual Network残差网络。Kaiming He 的《Deep Residual Learning for Image Recognition》获得了CVPR最佳论文。他提出的深度残差网络在2015年可以说是洗刷了图像方面的各大比赛,以绝对优势取得了多个比赛的冠军。而且它在保证网络精度的前提下,将网络的深度达到了152层,后来又进一步加到1000的深度。论文的开篇先是说明了深度网络的好处:特征等级随着网络的加深而变高,网络的表达能力也会大大提高。因此论文中提出了一个问题:是否可以通过叠加网络层数来获得一个更好的网络呢?作者经过实验发现,单纯的把网络叠起来的深层网络的效果反而不如合适层数的较浅的网络效果。因此何恺明等人在普通平原网络的基础上增加了一个shortcut, 构成一个residual block。此时拟合目标就变为F(x),F(x)就是残差:

    • 训练模型
def evalute(model, loader):
model.eval()
correct = 0
total = len(loader.dataset)
for x, y in loader:
x, y = x.to(device), y.to(device)
with torch.no_grad():
logits = model(x)
pred = logits.argmax(dim=1)
correct += torch.eq(pred, y).sum().float().item()
return correct / total
def main():
model = ResNet18(5).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
best_acc, best_epoch = 0, 0
global_step = 0
viz.line([0], [-1], win='loss', opts=dict(title='loss'))
viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))
for epoch in range(epochs):
for step, (x, y) in enumerate(train_loader):
x, y = x.to(device), y.to(device)
model.train()
logits = model(x)
loss = criteon(logits, y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
viz.line([loss.item()], [global_step], win='loss', update='append')
global_step += 1
if epoch % 1 == 0:
val_acc = evalute(model, val_loader)
if val_acc > best_acc:
best_epoch = epoch
best_acc = val_acc
viz.line([val_acc], [global_step], win='val_acc', update='append')
print('best acc:', best_acc, 'best epoch:', best_epoch)
model.load_state_dict(torch.load('best.mdl'))
print('loaded from ckpt!')
test_acc = evalute(model, test_loader)
  • 迁移学习

    提升模型的准确率:

# model = ResNet18(5).to(device)
trained_model=resnet18(pretrained=True) # 此时是一个非常好的model
model = nn.Sequential(*list(trained_model.children())[:-1], # 此时使用的是前17层的网络 0-17 *:随机打散
Flatten(),
nn.Linear(512,5)
).to(device)
# x=torch.randn(2,3,224,224)
# print(model(x).shape)
optimizer = optim.Adam(model.parameters(), lr=lr)
criteon = nn.CrossEntropyLoss()
  • 保存、加载模型

    pytorch保存模型的方式有两种:

    第一种:将整个网络都都保存下来

    第二种:仅保存和加载模型参数(推荐使用这样的方法)
# 保存和加载整个模型
torch.save(model_object, 'model.pkl')
model = torch.load('model.pkl')
# 仅保存和加载模型参数(推荐使用)
torch.save(model_object.state_dict(), 'params.pkl')
model_object.load_state_dict(torch.load('params.pkl'))

可以看到这是我保存的模型:

其中best.mdl是第二中方法保存的

model.pkl则是第一种方法保存的

  • 测试模型

    这里是训练时的情况



    看这个数据准确率还是不错的,但是还是需要实际的测试这个模型,看它到底学到东西了没有,接下来简单的测试一下:

import torch
from PIL import Image
from torchvision import transforms
device = torch.device('cuda')
transform=transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485,0.456,0.406],
std=[0.229,0.224,0.225])
])
def prediect(img_path):
net=torch.load('model.pkl')
net=net.to(device)
torch.no_grad()
img=Image.open(img_path)
img=transform(img).unsqueeze(0)
img_ = img.to(device)
outputs = net(img_)
_, predicted = torch.max(outputs, 1)
# print(predicted)
print('this picture maybe :',classes[predicted[0]])
if __name__ == '__main__':
prediect('./test/name.jpg')

实际的测试结果





效果还是可以的,完整的代码:

https://github.com/huzixuan1/Loader_DateSet

数据集下载链接:

https://pan.baidu.com/s/12-NQiF4fXEOKrXXdbdDPCg

由于笔者能力水平有限,在表述上可能有些不准确;有问题可以联系QQ:1017190168

PyTorch 实战(模型训练、模型加载、模型测试)的更多相关文章

  1. Pytorch实战学习(四):加载数据集

    <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili Dataset & Dataloader 1.Dataset & Dataloader作用 ※Datas ...

  2. 深度学习-05(tensorflow模型保存与加载、文件读取、图像分类:手写体识别、服饰识别)

    文章目录 深度学习-05 模型保存于加载 什么是模型保存与加载 模型保存于加载API 案例1:模型保存/加载 读取数据 文件读取机制 文件读取API 案例2:CSV文件读取 图片文件读取API 案例3 ...

  3. [译]Vulkan教程(31)加载模型

    [译]Vulkan教程(31)加载模型 Loading models 加载模型 Introduction 入门 Your program is now ready to render textured ...

  4. PyTorch保存模型与加载模型+Finetune预训练模型使用

    Pytorch 保存模型与加载模型 PyTorch之保存加载模型 参数初始化参 数的初始化其实就是对参数赋值.而我们需要学习的参数其实都是Variable,它其实是对Tensor的封装,同时提供了da ...

  5. [PyTorch 学习笔记] 7.1 模型保存与加载

    本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson7/model_save.py https://githu ...

  6. 全面解析Pytorch框架下模型存储,加载以及冻结

    最近在做试验中遇到了一些深度网络模型加载以及存储的问题,因此整理了一份比较全面的在 PyTorch 框架下有关模型的问题.首先咱们先定义一个网络来进行后续的分析: 1.本文通用的网络模型 import ...

  7. 深度学习原理与框架-猫狗图像识别-卷积神经网络(代码) 1.cv2.resize(图片压缩) 2..get_shape()[1:4].num_elements(获得最后三维度之和) 3.saver.save(训练参数的保存) 4.tf.train.import_meta_graph(加载模型结构) 5.saver.restore(训练参数载入)

    1.cv2.resize(image, (image_size, image_size), 0, 0, cv2.INTER_LINEAR) 参数说明:image表示输入图片,image_size表示变 ...

  8. [Pytorch]Pytorch 保存模型与加载模型(转)

    转自:知乎 目录: 保存模型与加载模型 冻结一部分参数,训练另一部分参数 采用不同的学习率进行训练 1.保存模型与加载 简单的保存与加载方法: # 保存整个网络 torch.save(net, PAT ...

  9. tensorflow 模型保存与加载 和TensorFlow serving + grpc + docker项目部署

    TensorFlow 模型保存与加载 TensorFlow中总共有两种保存和加载模型的方法.第一种是利用 tf.train.Saver() 来保存,第二种就是利用 SavedModel 来保存模型,接 ...

  10. tensorflow实现线性回归、以及模型保存与加载

    内容:包含tensorflow变量作用域.tensorboard收集.模型保存与加载.自定义命令行参数 1.知识点 """ 1.训练过程: 1.准备好特征和目标值 2.建 ...

随机推荐

  1. 遗传算法解决航路规划问题(MATLAB)

    遗传算法 文章部分图片和思路来自司守奎,孙兆亮<数学建模算法与应用>第二版 定义:遗传算法是一种基于自然选择原理和自然遗传机制的搜索(寻优)算法,模拟自然界中的声明进化机制,在人工系统中实 ...

  2. C++算法之旅、06 基础篇 | 第三章 图论

    常用代码模板3--搜索与图论 - AcWing DFS 尽可能往深处搜,遇到叶子节点(无路可走)回溯,恢复现场继续走 数据结构:stack 空间:需要记住路径上的点,\(O(h)\). BFS使用空间 ...

  3. 如何get一个终身免费续期的定制数字人?

    想拥有一个"数字分身" 吗?给你一个终身免费续期的特权. 定制周期长?训练.运营成本高?成片效果生硬?无法应用于实际场景? 随着AIGC技术的快速发展,虚拟数字人的生成效率不断提高 ...

  4. strimzi实战之三:prometheus+grafana监控(按官方文档搞不定监控?不妨看看本文,已经踩过坑了)

    欢迎访问我的GitHub 这里分类和汇总了欣宸的全部原创(含配套源码):https://github.com/zq2599/blog_demos 本篇概览 由于整个系列的实战都涉及到消息生产和消费,所 ...

  5. destoon上做纯js实现html指定页面导出word

    因为最近做了范文网站需要,所以要下载为word文档,如果php进行处理,很吃后台服务器,所以想用前端进行实现.查询github发现,确实有这方面的插件. js导出word文档所需要的两个插件: 1 2 ...

  6. MySQL的index merge(索引合并)导致数据库死锁分析与解决方案

    背景 在DBS-集群列表-更多-连接查询-死锁中,看到9月22日有数据库死锁日志,后排查发现是因为mysql的优化-index merge(索引合并)导致数据库死锁. 定义 index merge(索 ...

  7. Intervals 题解

    Intervals 题目大意 给定 \(m\) 条形如 \((l_i,r_i,a_i)\) 的规则,你需要求出一个长为 \(n\) 的分数最大的 01 串的分数,其中一个 01 串 \(A\) 的分数 ...

  8. Go 函数多返回值错误处理与error 类型介绍

    Go 函数多返回值错误处理与error 类型介绍 目录 Go 函数多返回值错误处理与error 类型介绍 一.error 类型与错误值构造 1.1 Error 接口介绍 1.2 构造错误值的方法 1. ...

  9. 如何在 Vue.js 中引入原子设计?

    本文为翻译文章,原文链接: https://medium.com/@9haroon_dev/introducing-atomic-design-in-vue-js-a9e873637a3e 前言 原子 ...

  10. [ABC321C] 321-like Searcher

    Problem 题目简述 给你一个 \(K\),求出 \([1 \sim K]\) 区间内有多少个 321-like Number. 321-like Number 的定义: 每一位上的数字从左到右严 ...