PyTorch基础

摘抄自《深度学习之Pytorch》。

Tensor(张量)

PyTorch里面处理的最基本的操作对象就是Tensor,表示的是一个多维矩阵,比如零维矩阵就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维数组,这和numpy是对应,而且PyTorch的Tensor可以和numpy的ndarray相互转换,唯一不同的是PyTorch可以在GPU上运行,而numpy的ndarray只能在CPU上运行。

常用的不同数据类型的Tensor有32位浮点型torch.FloatTensor64位浮点型torch.DoubleTensor16位整型torch.ShortTensor32位整型torch.IntTensor64位浮点型torch.LongTensor。torch.Tensor默认的是torch.FloatTensor数据类型。

Variable(变量)

Variable,变量,这个概念在numpy中是没有,是神经网络计算图里特有的一个概念,就是Variable提供了自动求导的功能。

Variable和Tensor本质上没有区别,不过Variable会被放入一个计算图中,然后进行前向传播、反向传播、自动求导。

首先Variable是在torch.autograd.Variable中,要将tensor变成Variable也非常简单,比如想让一个tensor a变成Variable,只需Variable(a)即可。

Variable有三个比较重要的组成属性:data,grad和grad_fn。通过data可以取出Variable里面的tensor数值;grad_fn表示的是得到这个Variable的操作,比如通过加减还是乘除得到的;grad是这个Variable的反向传播梯度。

Dataset(数据集)

在处理任何机器学习问题之前都需要数据读取,并进行预处理。PyTorch提供了很多工具使得数据的读取和预处理变得很容易。

torch.utils.data.Dataset是代表这一数据的抽象类。你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__len__和__getitem__这个两个函数:

class myDataset(Dataset):
def __init__(self,csv_file,txt_file,root_dir,other_file):
self.csv_data = pd.read_csv(csv_file)
with open(txt_file,'r') as f:
data_list = f.readlines()
self.txt_data = data_list
self.root_dir = root_dir def __len__(self):
return len(self.csv_data) def __gettime__(self,idx):
data = (self.csv_data[idx],self.txt_data[idx])
return data

通过上面的方式,可以定义我们需要的数据类,可以同迭代的方式来获取每一个数据,但这样很难实现缺batch,shuffle或者是多线程去读取数据,所以PyTorch中提供了一个简单的办法来做这个事情,通过torch.utils.data.DataLoader来定义一个新的迭代器,如下:

dataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=defaulf_collate)

其中的参数都很清楚,只有collate_fn是标识如何去样本的,我们可以定义自己的函数来准确地实现想要的功能,默认的函数在一般情况下都是可以使用的。

nn.Module(模组)

在PyTorch里面编写神经网络,所有的层结构和损失函数都来自于torch.nn,所有的模型构建都是从这个基类nn.Module继承的,于是有了下面的这个模板。

class net_name(nn.Module):
def __init__(self,other_arguments):
super(net_name,self).__init__()
self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size)
# other network layer def forward(self,x):
x = self.conv1(x)
return x

这样就建立一个计算图,并且这个结构可以复用多次,每次调用就相当于用该计算图定义的相同参数做一次前向传播,得益于PyTorch的自动求导功能,所以我们不需要自己编写反向传播。

定义完模型之后,我们需要通过nn这个包来定义损失函数。常见的损失函数都已经定义在了nn中,比如均方误差、多分类的交叉熵以及二分类的交叉熵等等,调用这些已经定义好的的损失函数也很简单:

criterion = nn.CrossEntropyLoss()]
loss = criterion(output,target)

这样就能求得我们的输出和真实目标之间的损失函数了。

torch.optim(优化)

在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化(或最大化),优化算法就是一种调整模型参数更新的策略。

优化算法分为两大类:

  1. 一阶优化算法

这种算法使用各个参数的梯度值来更新参数,最常用的一阶优化算法是梯度下降。所谓的梯度就是导数的多变量表达式,函数的梯度形成了一个向量场,同时也是一个方向,这个方向上方向导数最大,且等于梯度。梯度下降的功能是通过寻找最小值,控制方差,更新模型参数,最终使模型收敛,网络的参数更新公式如下:

\(\theta = \theta - \eta \times\frac{\sigma J(\theta)}{\sigma_\theta}\)

其中\(\eta\)是学习率,\(\frac{\sigma J(\theta)}{\sigma_\theta}\)是函数的梯度。这是深度学习里最常用的优化方法。

  1. 二阶优化算法

二阶优化算法是用来二阶导数(也叫做Hessian方法)来最小化或最大化损失函数,主要基于牛顿法,但由于二阶导数的计算成本很高,所以这种方法并没有广泛使用。torch.optim是一个实现各种优化算法的包,大多数常见的算法都能到直接通过这个包来调用,比如随机梯度下降,以及添加动量的随机梯度下降,自适应学习率等。在调用的时候需要优话传入的参数,这些参数都必须是Variable,然后传入一些基本的设定,比如学习率和动量等。

模型的保存和加载

在PyTorch中使用torch.save来保存模型的结构和参数,有两种保存方式:

  1. 保存整个模型的结构信息和参数信息,保存对象是模型model;
  2. 保存模型的参数,保存的对象是模型的状态model.state_dict()。

可以按如下方式保存:save的第一个参数是保存对象,第二个参数是保存路径及名称:

    torch.save(model,'./model.pth')
torch.save(model.state_dict(),'./model_state.pth')

加载模型有两种对应于保存模型的方式:

  1. 加载完整的模型结构和参数信息,使用 load_model = torch.load('model.pth'),在网络较大的时候加载的时间较长,同时存储空间也比较大;
  2. 加载模型参数信息,需要先导入模型的结构,然后通过 model.load_state_dic(torch.load('model_state.pth'))来导入。

[人工智能]Pytorch基础的更多相关文章

  1. 【新生学习】第一周:深度学习及pytorch基础

    DEADLINE: 2020-07-25 22:00 写在最前面: 本课程的主要思路还是要求大家大量练习 pytorch 代码,在写代码的过程中掌握深度学习的各类算法,希望大家能够坚持练习,相信经度过 ...

  2. pytorch基础学习(二)

    在神经网络训练时,还涉及到一些tricks,如网络权重的初始化方法,优化器种类(权重更新),图片预处理等,继续填坑. 1. 神经网络初始化(Network Initialization ) 1.1 初 ...

  3. PyTorch基础——词向量(Word Vector)技术

    一.介绍 内容 将接触现代 NLP 技术的基础:词向量技术. 第一个是构建一个简单的 N-Gram 语言模型,它可以根据 N 个历史词汇预测下一个单词,从而得到每一个单词的向量表示. 第二个将接触到现 ...

  4. pytorch 基础内容

    一些基础的操作: import torch as th a=th.rand(3,4) #随机数,维度为3,4的tensor b=th.rand(4)print(a)print(b) a+b tenso ...

  5. 饮冰三年-人工智能-Python-17Python基础之模块与包

    一.模块(modue) 简单理解一个.py文件就称之为一个模块. 1.1 模块种类: python标准库 第三方模板 应用程序自定义模块(尽量不要与内置函数重名) 1.2 模块导入方法 # impor ...

  6. Pytorch 基础

    Pytorch 1.0.0 学习笔记: Pytorch 的学习可以参考:Welcome to PyTorch Tutorials Pytorch 是什么? 快速上手 Pytorch! Tensors( ...

  7. pytorch基础教程1

    0.迅速入门:根据上一个博客先安装好,然后终端python进入,import torch ******************************************************* ...

  8. 【pytorch】pytorch基础学习

    目录 1. 前言 # 2. Deep Learning with PyTorch: A 60 Minute Blitz 2.1 base operations 2.2 train a classifi ...

  9. Pytorch基础(6)----参数初始化

    一.使用Numpy初始化:[直接对Tensor操作] 对Sequential模型的参数进行修改: import numpy as np import torch from torch import n ...

随机推荐

  1. upstream(负载均衡)

    一.什么是负载均衡 负载均衡,顾名思义是指将负载尽量均衡的分摊到多个不同的服务器,以保证服务的可用性和可靠性,提供给客户更好的用户体验: 负载均衡的直接目标就是尽量发挥多个服务单元的整体效能,要实现这 ...

  2. js 选中div中的文本

    function selectText(element) { var text = document.getElementById(element); if (document.body.create ...

  3. elasticsearch minhash 测试应用

    上一章看了代码实现,算是搞明白了各参数的意义,现在开始测试,为方便以ik分词示例(对elasticsearch支持较好,测试操作简单) 首先建index,自定义 analysis ik分词用 ik_s ...

  4. vim 高级技巧

    复制粘贴 normal 或v模式下 y/d/x 复制后,p来粘贴 编辑模式 默认的 set autoindent 会导致粘贴代码会导致缩进混乱 一则可以先关掉autoindent,二则可以先设置set ...

  5. 前端Js复习-前后台的搭建-结合Bootstrap和JQuery搭建vue项目

    流式布局思想 """ 页面的尺寸改变动态改变页面布局,或是通过父集标签控制多个子标签,这种布局思想就称之为 - 流式布局思想 1) 将标签宽高设置成 百分比,就可以随屏幕 ...

  6. Matlab高级教程_第三篇:Matlab转码C/C++方式(混编)_第二部分

    这一部分通过一些实例来进行转码和调试的讲解: 1. 输入变量.输出变量和过程内变量的内存预分配 函数代码:函数名test function [A,B] = test( mark,num,array ) ...

  7. [LC] 146. LRU Cache

    Design and implement a data structure for Least Recently Used (LRU) cache. It should support the fol ...

  8. RHEL安装神器EPEL

    什么是EPEL? EPEL的全称叫 Extra Packages for Enterprise Linux .EPEL是由 Fedora 社区打造,为 RHEL 及衍生发行版如 CentOS.Scie ...

  9. PostgreSQL中实现更新默认值(二)

    今天我们用表继承+触发器的方案,来实现表中的更新默认值.这也许是PostgreSQL里最佳的解决方案. 一. 创建一张表,作为父表 create table basic_update( t_updat ...

  10. 59)PHP,管理员表中所存在的项

    用户ID 用户名 用户密码 用户权限(就是他的角色等级,比如是1级  2级,  三级等等) 上次登录的IP 上次登录的时间