[人工智能]Pytorch基础
PyTorch基础
摘抄自《深度学习之Pytorch》。
Tensor(张量)
PyTorch里面处理的最基本的操作对象就是Tensor,表示的是一个多维矩阵,比如零维矩阵就是一个点,一维就是向量,二维就是一般的矩阵,多维就相当于一个多维数组,这和numpy是对应,而且PyTorch的Tensor可以和numpy的ndarray相互转换,唯一不同的是PyTorch可以在GPU上运行,而numpy的ndarray只能在CPU上运行。
常用的不同数据类型的Tensor有32位浮点型torch.FloatTensor、64位浮点型torch.DoubleTensor、16位整型torch.ShortTensor、32位整型torch.IntTensor和64位浮点型torch.LongTensor。torch.Tensor默认的是torch.FloatTensor数据类型。
Variable(变量)
Variable,变量,这个概念在numpy中是没有,是神经网络计算图里特有的一个概念,就是Variable提供了自动求导的功能。
Variable和Tensor本质上没有区别,不过Variable会被放入一个计算图中,然后进行前向传播、反向传播、自动求导。
首先Variable是在torch.autograd.Variable中,要将tensor变成Variable也非常简单,比如想让一个tensor a变成Variable,只需Variable(a)即可。
Variable有三个比较重要的组成属性:data,grad和grad_fn。通过data可以取出Variable里面的tensor数值;grad_fn表示的是得到这个Variable的操作,比如通过加减还是乘除得到的;grad是这个Variable的反向传播梯度。
Dataset(数据集)
在处理任何机器学习问题之前都需要数据读取,并进行预处理。PyTorch提供了很多工具使得数据的读取和预处理变得很容易。
torch.utils.data.Dataset是代表这一数据的抽象类。你可以自己定义你的数据类,继承和重写这个抽象类,非常简单,只需要定义__len__和__getitem__这个两个函数:
class myDataset(Dataset):
def __init__(self,csv_file,txt_file,root_dir,other_file):
self.csv_data = pd.read_csv(csv_file)
with open(txt_file,'r') as f:
data_list = f.readlines()
self.txt_data = data_list
self.root_dir = root_dir
def __len__(self):
return len(self.csv_data)
def __gettime__(self,idx):
data = (self.csv_data[idx],self.txt_data[idx])
return data
通过上面的方式,可以定义我们需要的数据类,可以同迭代的方式来获取每一个数据,但这样很难实现缺batch,shuffle或者是多线程去读取数据,所以PyTorch中提供了一个简单的办法来做这个事情,通过torch.utils.data.DataLoader来定义一个新的迭代器,如下:
dataiter = DataLoader(myDataset,batch_size=32,shuffle=True,collate_fn=defaulf_collate)
其中的参数都很清楚,只有collate_fn是标识如何去样本的,我们可以定义自己的函数来准确地实现想要的功能,默认的函数在一般情况下都是可以使用的。
nn.Module(模组)
在PyTorch里面编写神经网络,所有的层结构和损失函数都来自于torch.nn,所有的模型构建都是从这个基类nn.Module继承的,于是有了下面的这个模板。
class net_name(nn.Module):
def __init__(self,other_arguments):
super(net_name,self).__init__()
self.conv1 = nn.Conv2d(in_channels,out_channels,kernel_size)
# other network layer
def forward(self,x):
x = self.conv1(x)
return x
这样就建立一个计算图,并且这个结构可以复用多次,每次调用就相当于用该计算图定义的相同参数做一次前向传播,得益于PyTorch的自动求导功能,所以我们不需要自己编写反向传播。
定义完模型之后,我们需要通过nn这个包来定义损失函数。常见的损失函数都已经定义在了nn中,比如均方误差、多分类的交叉熵以及二分类的交叉熵等等,调用这些已经定义好的的损失函数也很简单:
criterion = nn.CrossEntropyLoss()]
loss = criterion(output,target)
这样就能求得我们的输出和真实目标之间的损失函数了。
torch.optim(优化)
在机器学习或者深度学习中,我们需要通过修改参数使得损失函数最小化(或最大化),优化算法就是一种调整模型参数更新的策略。
优化算法分为两大类:
- 一阶优化算法
这种算法使用各个参数的梯度值来更新参数,最常用的一阶优化算法是梯度下降。所谓的梯度就是导数的多变量表达式,函数的梯度形成了一个向量场,同时也是一个方向,这个方向上方向导数最大,且等于梯度。梯度下降的功能是通过寻找最小值,控制方差,更新模型参数,最终使模型收敛,网络的参数更新公式如下:
\(\theta = \theta - \eta \times\frac{\sigma J(\theta)}{\sigma_\theta}\)
其中\(\eta\)是学习率,\(\frac{\sigma J(\theta)}{\sigma_\theta}\)是函数的梯度。这是深度学习里最常用的优化方法。
- 二阶优化算法
二阶优化算法是用来二阶导数(也叫做Hessian方法)来最小化或最大化损失函数,主要基于牛顿法,但由于二阶导数的计算成本很高,所以这种方法并没有广泛使用。torch.optim是一个实现各种优化算法的包,大多数常见的算法都能到直接通过这个包来调用,比如随机梯度下降,以及添加动量的随机梯度下降,自适应学习率等。在调用的时候需要优话传入的参数,这些参数都必须是Variable,然后传入一些基本的设定,比如学习率和动量等。
模型的保存和加载
在PyTorch中使用torch.save来保存模型的结构和参数,有两种保存方式:
- 保存整个模型的结构信息和参数信息,保存对象是模型model;
- 保存模型的参数,保存的对象是模型的状态model.state_dict()。
可以按如下方式保存:save的第一个参数是保存对象,第二个参数是保存路径及名称:
torch.save(model,'./model.pth')
torch.save(model.state_dict(),'./model_state.pth')
加载模型有两种对应于保存模型的方式:
- 加载完整的模型结构和参数信息,使用 load_model = torch.load('model.pth'),在网络较大的时候加载的时间较长,同时存储空间也比较大;
- 加载模型参数信息,需要先导入模型的结构,然后通过 model.load_state_dic(torch.load('model_state.pth'))来导入。
[人工智能]Pytorch基础的更多相关文章
- 【新生学习】第一周:深度学习及pytorch基础
DEADLINE: 2020-07-25 22:00 写在最前面: 本课程的主要思路还是要求大家大量练习 pytorch 代码,在写代码的过程中掌握深度学习的各类算法,希望大家能够坚持练习,相信经度过 ...
- pytorch基础学习(二)
在神经网络训练时,还涉及到一些tricks,如网络权重的初始化方法,优化器种类(权重更新),图片预处理等,继续填坑. 1. 神经网络初始化(Network Initialization ) 1.1 初 ...
- PyTorch基础——词向量(Word Vector)技术
一.介绍 内容 将接触现代 NLP 技术的基础:词向量技术. 第一个是构建一个简单的 N-Gram 语言模型,它可以根据 N 个历史词汇预测下一个单词,从而得到每一个单词的向量表示. 第二个将接触到现 ...
- pytorch 基础内容
一些基础的操作: import torch as th a=th.rand(3,4) #随机数,维度为3,4的tensor b=th.rand(4)print(a)print(b) a+b tenso ...
- 饮冰三年-人工智能-Python-17Python基础之模块与包
一.模块(modue) 简单理解一个.py文件就称之为一个模块. 1.1 模块种类: python标准库 第三方模板 应用程序自定义模块(尽量不要与内置函数重名) 1.2 模块导入方法 # impor ...
- Pytorch 基础
Pytorch 1.0.0 学习笔记: Pytorch 的学习可以参考:Welcome to PyTorch Tutorials Pytorch 是什么? 快速上手 Pytorch! Tensors( ...
- pytorch基础教程1
0.迅速入门:根据上一个博客先安装好,然后终端python进入,import torch ******************************************************* ...
- 【pytorch】pytorch基础学习
目录 1. 前言 # 2. Deep Learning with PyTorch: A 60 Minute Blitz 2.1 base operations 2.2 train a classifi ...
- Pytorch基础(6)----参数初始化
一.使用Numpy初始化:[直接对Tensor操作] 对Sequential模型的参数进行修改: import numpy as np import torch from torch import n ...
随机推荐
- Python笔记_第一篇_面向过程_第一部分_5.Python数据类型之列表类型(list)
Python中序列是最基本的数据结构.序列中的每个元素都分配一个数字(他的位置或者索引),第一个索引是0,第二个索引是1,依次类推.Python的列表数据类型类似于C语言中的数组,但是不同之处在于列表 ...
- 基于JSP+Servlet开发在线租车系统 java 源码
运行环境: 最好是java jdk 1.8,我们在这个平台上运行的.其他版本理论上也可以.IDE环境: Eclipse,Myeclipse,IDEA都可以tomcat环境: Tomcat 7.x,8. ...
- day58-mysql-视图,触发器
一. 视图 .1创建视图 create view p_view as select name,age from person; 视图的作用是隐藏数据,例如上面语句没有查询工资,是为了隐藏它,这样就避免 ...
- Python语言学习前提:Pycharm的使用
一.Pycharm的使用 1.点击Pycharm的图标 2.点击首页Create New Project > 在弹出的页面点击Pure Python 3.选择项目文件存放的位置,选择完成之后点击 ...
- 表查询语句及使用-连表(inner join-left join)-子查询
一.表的基本查询语句及方法 from. where. group by(分组).having(分组后的筛选).distinct(去重).order by(排序). limit(限制) 1.单表查询: ...
- 894A. QAQ#(暴力)
题目出处:http://codeforces.com/problemset/problem/894/A 题目大意:计数依次出现QAQ的次数 #include<iostream> using ...
- mpu6050的驱动的加载和测试步骤
mpu6050的使用方法: 1.接线方式: VCC,GND,SCL,SDA,正常接法,VCC接3.3v,主要说一下AD0引脚,用来表示地址 接低电平设备地址为0x68,接高电平表示0x69 2.设备接 ...
- 量化投资_轻松实现MATLAB蒙特卡洛方法建模
1 目录 * MATLAB随机数的产生 - Uniform,Normal & Custom distributions * 蒙特卡洛仿真 * 产生股票价格路径 * 期权定价 - 经典公式 - ...
- shell_跳板机推送公钥
#!/bin/bash#push publickey to aap-servers#将局域网内可以ping通的主机ip保存到一个文件> ip_up.txtfor i in {2..10}do { ...
- μC/OS-II中使用软件定时器
在试着将μC/OS-II移植到ARM7芯片(LPC2138)上的过程中,发现使用OSTmrCreate创建的OSTmr始终都不能执行CallbackFunction,OS版本是v2.85,最后是这么解 ...