PyTorch使用总览

https://blog.csdn.net/u014380165/article/details/79222243

深度学习框架训练模型时的代码主要包含数据读取、网络构建和其他设置三方面,基本上掌握这三方面就可以较为灵活地使用框架训练模型。PyTorch是Facebook的官方深度学习框架之一,到现在开源1年时间,势头非常猛,相信使用过的人都会被其轻便和快速等特点深深吸引,因此这篇博客从整体上介绍如何使用PyTorch

PyTorch的官方github地址:https://github.com/pytorch/pytorch
PyTorch官方文档:http://pytorch.org/docs/0.3.0/

建议先看看:PyTorch学习之路(level1)——训练一个图像分类模型,对Pytorch的使用有一个快速的了解。

接下来就按照上述的3个方面来介绍如何使用PyTorch。

一、数据读取

数据读取部分包含如何将你的图像和标签数据转换成PyTorch框架的Tensor数据类型,官方代码库中有一个接口例子:torchvision.ImageFolder,这个接口在PyTorch学习之路(level1)——训练一个图像分类模型 中有简单介绍。因为这个接口针对的数据存放方式是每个文件夹包含一个类的图像,但是实际应用中可能你的数据不是这样维护的,或者你的数据是多标签的,或者其他更复杂的形式,那么就需要自定义一个数据读取接口,这个时候就不得不提一个PyTorch中数据读取基类:torch.utils.data.Dataset,包括前面提到的torchvision.ImageFolder接口的对应类也是继承torch.utils.data.Dataset实现的,因此torch.utils.data.Dataset类是PyTorch框架中数据读取的核心。那么如何自定义一个数据读取接口呢?可以看博客:PyTorch学习之路(level2)——自定义数据读取,这篇博客中从剖析torchvision.ImageFolder接口切入,然后引出如何自定义数据读取接口。这样就完成了数据的第一层封装。

在自定义数据读取接口时还有一步很重要的操作:数据预处理。常常我们在论文中看到的data argumentation就是指的数据预处理,对实验结果影响还是比较大的。该操作在PyTorch中可以通过torchvision.transforms接口来实现,具体请看博客:PyTorch源码解读之torchvision.transforms 的介绍。

经过上述的两个操作后,还需再进行一次封装,将数据和标签封装成数据迭代器,这样才方便模型训练的时候一个batch一个batch地进行,这就要用到torch.utils.data.DataLoader接口,该接口的一个输入就是前面继承自torch.utils.data.Dataset类的自定义了的对象(比如torchvision.ImageFolder类的对象),具体可以参考博客:PyTorch源码解读之torch.utils.data.DataLoader

至此,从图像和标签文件就生成了Tensor类型的数据迭代器,后续仅需将Tensor对象用torch.autograd.Variable接口封装成Variable类型(比如train_data=torch.autograd.Variable(train_data),如果要在gpu上运行则是:train_data=torch.autograd.Variable(train_data.cuda()))就可以作为模型的输入了。

其他自定义的数据读取接口例子可以参考:https://github.com/miraclewkf/MobileNetV2-PyTorch,该项目中的read_ImageNetData.py脚本自定义了读取ImageNet数据集的接口,训练数据的读取和验证数据的读取采取不同的接口实现,比较有特点。

二、网络构建

PyTorch框架中提供了一些方便使用的网络结构及预训练模型接口:torchvision.models,具体可以看博客:PyTorch源码解读之torchvision.models。该接口可以直接导入指定的网络结构,并且可以选择是否用预训练模型初始化导入的网络结构。

那么如何自定义网络结构呢?在PyTorch中,构建网络结构的类都是基于torch.nn.Module这个基类进行的,也就是说所有网络结构的构建都可以通过继承该类来实现,包括torchvision.models接口中的模型实现类也是继承这个基类进行重写的。自定义网络结构可以参考:1、https://github.com/miraclewkf/MobileNetV2-PyTorch。该项目中的MobileNetV2.py脚本自定义了网络结构。2、https://github.com/miraclewkf/SENet-PyTorch。该项目中的se_resnet.py和se_resnext.py脚本分别自定义了不同的网络结构。

如果要用某预训练模型为自定义的网络结构进行参数初始化,可以用torch.load接口导入预训练模型,然后调用自定义的网络结构对象的load_state_dict方式进行参数初始化,具体可以看https://github.com/miraclewkf/MobileNetV2-PyTorch项目中的train.py脚本中if args.resume条件语句。

三、其他设置

优化函数通过torch.optim包实现,比如torch.optim.SGD()接口表示随机梯度下降。更多优化函数可以看官方文档:http://pytorch.org/docs/0.3.0/optim.html

学习率策略通过torch.optim.lr_scheduler接口实现,比如torch.optim.lr_scheduler.StepLR()接口表示按指定epoch数减少学习率。更多学习率变化策略可以看官方文档:http://pytorch.org/docs/0.3.0/optim.html

损失函数通过torch.nn包实现,比如torch.nn.CrossEntropyLoss()接口表示交叉熵等。

多GPU训练通过torch.nn.DataParallel接口实现,比如:model = torch.nn.DataParallel(model, device_ids=[0,1])表示在gpu0和1上训练模型。

PyTorch使用总览的更多相关文章

  1. 【转载】PyTorch系列 (二):pytorch数据读取

    原文:https://likewind.top/2019/02/01/Pytorch-dataprocess/ Pytorch系列: PyTorch系列(一) - PyTorch使用总览 PyTorc ...

  2. 步步深入:MySQL架构总览->查询执行流程->SQL解析顺序

    前言: 一直是想知道一条SQL语句是怎么被执行的,它执行的顺序是怎样的,然后查看总结各方资料,就有了下面这一篇博文了. 本文将从MySQL总体架构--->查询执行流程--->语句执行顺序来 ...

  3. 基于Metronic的Bootstrap开发框架总览

    基于Metronic的Bootstrap开发框架经验总结(1)-框架总览及菜单模块的处理 最近一直很多事情,博客停下来好久没写了,整理下思路,把最近研究的基于Metronic的Bootstrap开发框 ...

  4. ES6入门系列三(特性总览下)

    0.导言 最近从coffee切换到js,代码量一下子变大了不少,也多了些许陌生感.为了在JS代码中,更合理的使用ES6的新特性,特在此对ES6的特性做一个简单的总览. 1.模块(Module) --C ...

  5. 基于Metronic的Bootstrap开发框架经验总结(1)-框架总览及菜单模块的处理

    最近一直很多事情,博客停下来好久没写了,整理下思路,把最近研究的基于Metronic的Bootstrap开发框架进行经验的总结出来和大家分享下,同时也记录自己对Bootstrap开发的学习研究的点点滴 ...

  6. 《zw版·Halcon-delphi系列原创教程》 Halcon分类函数·简明中文手册 总览

    <zw版·Halcon-delphi系列原创教程> Halcon分类函数·简明中文手册 总览 Halcon函数库非常庞大,光HALCONXLib_TLB.pas文件,源码就要7w多行,但核 ...

  7. Android数据的四种存储方式SharedPreferences、SQLite、Content Provider和File (一) —— 总览

    Android数据的四种存储方式SharedPreferences.SQLite.Content Provider和File (一) —— 总览   作为一个完成的应用程序,数据存储操作是必不可少的. ...

  8. 实战:ADFS3.0单点登录系列-总览

    本系列将以一个实际项目为背景,介绍如何使用ADFS3.0实现SSO.其中包括SharePoint,MVC,Exchange等应用程序的SSO集成. 整个系列将会由如下几个部分构成: 实战:ADFS3. ...

  9. Orchard官方文档翻译(一) 总览

    原文地址:http://docs.orchardproject.net/ 最近想要学习了解orchard,但却没有找到相关的中文文档,只有英文文档.于是决定自行翻译,以便日后方便翻阅. 转载请注明原作 ...

随机推荐

  1. 在谷歌安装扩展程序Axure RP Extension for Chrome后,经常无故损坏,无法使用

    最近因为要看需求给的原型图,但需求只给了html格式的文件,没有给可以在Axure软件里看的格式, 所以在谷歌安装了一个Axure RP Extension for Chrome扩展程序在谷歌浏览器看 ...

  2. Python的GUI编程(TK)

    TK在大多数 Unix平台.Windows平台和Macintosh系统都是预装好的,TKinter 模块是 Tk GUI 套件的标准Python接口.可实现Python的GUI编程. Tkinter模 ...

  3. SVN密码找回 完美方案

    问题背景 SVN(Subversion)版本管理工具.本文以Windows操作系统下使用SVN的场景. 长时间不使用SVN,可能会出现忘记了SVN密码的尴尬局面.那么,该如何找回SV密码呢? 处理思路 ...

  4. [Java Plasterer] Java Components 3:Java Enum

    Writer:BYSocket(泥沙砖瓦浆木匠) 微博:BYSocket 豆瓣:BYSocket Reprint it anywhere u want. Written In The Font Whe ...

  5. ZOJ Problem Set - 3708 Density of Power Network

    http://acm.zju.edu.cn/onlinejudge/showProblem.do?problemCode=3708 #include <stdio.h> #include ...

  6. java与json,一篇就够了

    本示例使用的json包为阿里的fastjson 首先写三个工具类(seter和geter方法省略,自行补上): /** * 屏幕实体类 */ public class Screen { private ...

  7. 07 训练Tensorflow识别手写数字

    打开Python Shell,输入以下代码: import tensorflow as tf from tensorflow.examples.tutorials.mnist import input ...

  8. 痞子衡嵌入式:SEGGER J-Link仿真器硬件版本变迁

    大家好,我是痞子衡,是正经搞技术的痞子.今天痞子衡给大家介绍的是J-Link仿真器版本变迁. 硬件版本 主控芯片 固件升级工具 V7 ARM7TDMI, 55MHz Atmel AT91SAM7S64 ...

  9. Hibernate学习(七)———— hibernate中查询方式详解

    序言 之前对hibernate中的查询总是搞混淆,不明白里面具体有哪些东西.就是因为缺少总结.在看这篇文章之前,你应该知道的是数据库的一些查询操作,多表查询等 --WH 一.hibernate中的5种 ...

  10. Spring Cloud Stream如何消费自己生产的消息?

    在上一篇<Spring Cloud Stream如何处理消息重复消费>中,我们通过消费组的配置解决了多实例部署情况下消息重复消费这一入门时的常见问题.本文将继续说说在另外一个被经常问到的问 ...