学习pytorch路程之动手学深度学习-3.4-3.7

置信度、置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的

交叉熵参考博客:https://www.cnblogs.com/kyrieng/p/8694705.html   https://blog.csdn.net/tsyccnh/article/details/79163834  个人感觉还不错,好理解

(这段瞅瞅就行了)torchvision包,服务于PyTorch深度学习框架的,用于构建计算机视觉模型,主要构成有:

  torchvision.datasets:加载数据的函数及常用的数据集接口

  torchvision.mdoels:包含常用的模型结构(含预训练模型),例如AlexNet、VGG、ResNet等

  torchvision.transforms:常用的图片变换,例如裁剪、旋转等

  torchvision.utils:其他一些有用的方法

一行里画出多张图像和对应标签的函数

def show_fashion_mnist(images, labels):
use_svg_display()
# 这里的_表示我们忽略(不使用)的变量
_, figs = plt.subplots(1, len(images), figsize=(12, 12))
  
  # subplots(row, cloumn, figsize=(12, 12)) row:子图的行数,即有几行;column:子图的列数,即一行有几个图;
  # figsize:子图的Height和Width 其他参数请参考博客https://www.cnblogs.com/zhouzhe-blog/p/9614761.html 感谢大佬
  # fig, ax = subplots() fig:返回一个图像fig(这是一个整幅图像,含有子图);ax:返回子图列表(fig的子图) for f, img, lbl in zip(figs, images, labels):
     # zip()把可迭代的对象中的元素打包成一个个的元组,返回这些元组的列表。若是可迭代对象的元素个数不一致,
     # 则返回列表的长度与最短的对象相同https://www.runoob.com/python/python-func-zip.html     
f.imshow(img.view((28, 28)).numpy()) # imshow(x,cmap) x表示要显示图片的变量,cmap为颜色图谱,默认为RGB(A)
# imshow()其他参数可百度
     f.set_title(lbl) # 设置图像标题
f.axes.get_xaxis().set_visible(False) # 设置x轴不可见
f.axes.get_yaxis().set_visible(False) # 设置y轴不可见
plt.show()

查看数据集前10个图片及对应标签

X, y = [], []
for i in range(10):
X.append(mnist_train[i][0]) # 每个图片对应的tensor(形状如torch.size([1, 28, 28]))附加到列表X中
y.append(mnist_train[i][1])  # 每个图像对应的labels的数字代表(即用0-9代替了标签)
show_fashion_mnist(X, get_fashion_mnist_labels(y))
# get_fashion_mnist_labels()通过数字获取对应的标签 show_fashion_mnist()显示图片和对应的label

softmax回归实现

  torch.gather()函数的理解

  下面说说我对gather()函数的理解,gather(input,dim,index,out=None),对于gather函数我百度了一些解释,参考了一些博客,如  https://www.jianshu.com/p/5d1f8cd5fe31  (对dim=0时的三维tensor解释的可以)   https://blog.csdn.net/edogawachia/article/details/80515038

以上述第一个博客的tensor为例,

input = tensor([[[18.,   5.,   7.,   1.,   1.],
     [ 3., 26., 9., 7., 9.],
     [ 10., 28., 22., 27., 0.]],      [[ 26., 10., 20., 29., 18.],
     [ 5., 24., 26., 21., 3.],
     [ 10., 29., 10., 0., 22.]]])
index = torch.LongTensor([[[0,1,2,0,2],
[0,0,0,0,0],
[1,1,1,1,1]],
[[1,2,2,2,2],
[0,0,0,0,0],
[2,2,2,2,2]]])
首先明确tensor的形状是(2,3,5),即第一层中括号里有2个维度,第二层中括号里有3个维度,第三层中括号里有5个维度,所以(2,3,5)
dim = 1
  即用index里值修改代替index对应位置的数值的下标中第2个下标值。例如index中下标为(0,0,0)的对应值0,用对应值0替换
  下标中第2个值0得出(0,0,0),然后去input中取下标为(0,0,0)的值18放入输出tensor下标(0,0,0)对应位置index中下标为(0,0,1)的对应
  值1,用对应值1替换下标中第2个值得出(0,1,1),取出input中下标为(0,1,1)的值为26放入输出tensor下标(0,0,1)对应位置index中下标为(1,0,1)
  对应值2,用对应值2替换下标中第2个值得出(1,2,1),取出input中下标为(1,2,1)的值为29放入输出tensor下标(1,0,1)对应位置。
  依次类推,得出输出tensor如下:
  tensor([[[ 18., 26., 22., 1., 0.],
        [ 18., 5., 7., 1., 1.],
        [ 3., 26., 9., 7., 9.]],
        [[ 5., 29., 10., 0., 22.],
        [ 26., 10., 20., 29., 18.],
        [ 10., 29., 10., 0., 22.]]])
dim = 2
  同理dim=1,这次只不过是替换下标中第3个值。如index中下标为(0,0,4)的对应值2,替换后得(0,0,2),取input值为7
index2 = torch.LongTensor([[[0,1,1,0,1],
[0,1,1,1,1],
[1,1,1,1,1]],
[[1,0,0,0,0],
[0,0,0,0,0],
[1,1,0,0,0]]])
dim = 0
  同理dim=1,这次是替换下标中第1个值。如index2中下标为(0,0,1)的对应值1,替换后的(1,0,1),取input值为10

torch.argmax()函数的理解

  我真是菜的一批,看了好久(一看就困),以博客 https://blog.csdn.net/weixin_42494287/article/details/92797061 的三维的例子为例,说一下个人的理解

a=torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],
[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
# a的形状为(2,3,4)比较时,若是第一块和第二块比,值一样,则取第二块的位置值;dim表示维度,要干掉的维度
b=torch.argmax(a,dim=0)
  dim=0,干掉第一维,也就是剩下(3,4),即红色中括号里对应位置数字相比,第一个红色中括号对应位置数字大,则取0,否则取1。
  值相同取第二个红色中括号对应值1(对应上面的黑色粗体字)。(为什么取0/1?)
  结果:
  tensor([[0, 1, 0, 1],
      [1, 1, 1, 1],
      [1, 1, 1, 1]])
c=torch.argmax(a,dim=1)
  dim=1,干掉第二维,也就是剩下(2,4),即蓝色和荧光绿色中括号里对应位置数字相对,蓝色剩一行,荧光绿色剩一行。以蓝色为例,第一个蓝色中括号
  数值大则取0,第二个蓝色中括号数值大则取1,第三个蓝色中括号数值大则取2。(为什么取0/1/2?)
  结果:
  tensor([[1, 2, 0, 1],
      [1, 2, 2, 1]])
d=torch.argmax(a,dim=2)
  dim=2,干掉第三维,也就是剩下(2,3),即红色中括号为2,蓝色和荧光绿色为3,也就是蓝色和荧光绿色中括号里剩一个值并转置(或者说取了值后放
  在一个列表里)。例如:第一个蓝色中括号:有2个5,取第二个5,取值2;第二个蓝色中括号:9最大,取值0;第三个蓝色中括号:7最大,取值1。组成列表为[2,0,1]
  结果:
  tensor([[2, 0, 1],
      [1, 0, 2]])
二维也一样,去掉某一维即可

看明白了argmax函数后,torch.sum(input,dim,output)也就懂了,不过sum是把要干掉的维度上数求和(即合并成一个)

torch.nn.CrossEntropyLoss()

  这个交叉熵损失函数把input tensor进行了softmax()、log()、NLLLoss()计算,即三合一。见博客 https://blog.csdn.net/qq_22210253/article/details/85229988 博主写的很好,感谢博主

小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())的更多相关文章

  1. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

  2. 小白学习之pytorch框架(7)之实战Kaggle比赛:房价预测(K折交叉验证、*args、**kwargs)

    本篇博客代码来自于<动手学深度学习>pytorch版,也是代码较多,解释较少的一篇.不过好多方法在我以前的博客都有提,所以这次没提.还有一个原因是,这篇博客的代码,只要好好看看肯定能看懂( ...

  3. 小白学习之pytorch框架(6)-模型选择(K折交叉验证)、欠拟合、过拟合(权重衰减法(=L2范数正则化)、丢弃法)、正向传播、反向传播

    下面要说的基本都是<动手学深度学习>这本花书上的内容,图也采用的书上的 首先说的是训练误差(模型在训练数据集上表现出的误差)和泛化误差(模型在任意一个测试数据集样本上表现出的误差的期望) ...

  4. 小白学习之pytorch框架(5)-多层感知机(MLP)-(tensor、variable、计算图、ReLU()、sigmoid()、tanh())

    先记录一下一开始学习torch时未曾记录(也未好好弄懂哈)导致又忘记了的tensor.variable.计算图 计算图 计算图直白的来说,就是数学公式(也叫模型)用图表示,这个图即计算图.借用 htt ...

  5. 小白学习之pytorch框架(2)-动手学深度学习(begin-random.shuffle()、torch.index_select()、nn.Module、nn.Sequential())

    在这向大家推荐一本书-花书-动手学深度学习pytorch版,原书用的深度学习框架是MXNet,这个框架经过Gluon重新再封装,使用风格非常接近pytorch,但是由于pytorch越来越火,个人又比 ...

  6. 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()

    模型训练的三要素:数据处理.损失函数.优化算法    数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...

  7. UFLDL深度学习笔记 (二)SoftMax 回归(矩阵化推导)

    UFLDL深度学习笔记 (二)Softmax 回归 本文为学习"UFLDL Softmax回归"的笔记与代码实现,文中略过了对代价函数求偏导的过程,本篇笔记主要补充求偏导步骤的详细 ...

  8. 从头学pytorch(四) softmax回归实现

    FashionMNIST数据集共70000个样本,60000个train,10000个test.共计10种类别. 通过如下方式下载. mnist_train = torchvision.dataset ...

  9. 家乐的深度学习笔记「4」 - softmax回归

    目录 softmax回归 分类问题 softmax回归模型 softmax运算 矢量表达式 单样本分类的矢量计算表达式 小批量样本分类的矢量计算表达式 交叉熵损失函数 模型预测及评价 图像分类数据集( ...

随机推荐

  1. 浏览器输入URL后,发生了什么?(打开一个网页会使用哪些协议?)

    简单说来分为6个过程 1.DNS解析 互联网每一个服务器都是以ip地址作为界限的,并不是我们平时输入的url地址,就比如www.ccc.com这一个网址,当输入进浏览器之后就会由DNS解析将它转化成一 ...

  2. 课堂测试_WEB界面链接数据库

    课堂测试_WEB界面链接数据库 一,题目: 一. 考试要求: 1登录账号:要求由6到12位字母.数字.下划线组成,只有字母可以开头:(1分) 2登录密码:要求显示“• ”或“*”表示输入位数,密码要求 ...

  3. zabbix监控linux 以及监控mysql

    Zabbix监控Linux主机设置方法 linux客户端 :59.128 安装了mysql 配置zabbix的yum源 rpm -ivh http://repo.zabbix.com/zabbix/2 ...

  4. 基于Qt5的排序算法简单可视化

    之前写了几个排序算法,然后看到别人将排序算法的过程可视化出来,所以就想尝试一下,然后就用Qt简单写了个界面,用QImage和QPainter来画图显示,代码比较简单. 我的想法是画图的时候,图像的X轴 ...

  5. BZOJ:1878: [SDOI2009]HH的项链

    题解:解法一:莫队 解法二:按区间左端点排序,让区间内最左边的贝壳对答案产生贡献,树状数组维护,转移对答案产生贡献的贝壳位置 #include<iostream> #include< ...

  6. PHP实现简易微信红包算法

    <?php /** * PHP实现简易的微信红包算法 * @version v1.0 * @author quetiezheng */ function getMoney($total, $pe ...

  7. .NET Core开发实战(第11课:文件配置提供程序)--学习笔记

    11 | 文件配置提供程序:自由选择配置的格式 文件配置提供程序 Microsoft.Extensions.Configuration.Ini Microsoft.Extensions.Configu ...

  8. kuangbin专题——简单搜索

    A - 棋盘问题 POJ - 1321 题意 在一个给定形状的棋盘(形状可能是不规则的)上面摆放棋子,棋子没有区别.要求摆放时任意的两个棋子不能放在棋盘中的同一行或者同一列,请编程求解对于给定形状和大 ...

  9. GPU 、APU、CUDA、TPU、FPGA介绍

    购买显卡主要关注:显存.带宽和浮点运算数量   GPU :图形处理器(英语:Graphics Processing Unit,缩写:GPU),又称显示核心.视觉处理器.显示芯片,是一种专门在个人电脑. ...

  10. 流程控制语句反汇编(1)(Debug版)

    // 流程控制语句反汇编 //Author:乾卦 Date:2014-5-8 #include<stdio.h> int main() { ,b=; if(a>b) { a=b; } ...