PyTorch 中的乘法:mul()、multiply()、matmul()、mm()、mv()、dot()
torch.mul()
函数功能:逐个对 input
和 other
中对应的元素相乘。
本操作支持广播,因此 input
和 other
均可以是张量或者数字。
举例如下:
>>> import torch
>>> a = torch.randn(3)
>>> a
tensor([-1.7095, 1.7837, 1.1865])
>>> b = 2
>>> torch.mul(a, b)
tensor([-3.4190, 3.5675, 2.3730]) # 这里将 other 扩展成了 input 的形状
>>> a = 3
>>> b = torch.randn(3, 1)
>>> b
tensor([[-0.7705],
[ 1.1177],
[ 1.2447]])
>>> torch.mul(a, b)
tensor([[-2.3116],
[ 3.3530],
[ 3.7341]]) # 这里将 input 扩展成了 other 的形状
>>> a = torch.tensor([[2], [3]])
>>> a
tensor([[2],
[3]]) # a 是 2×1 的张量
>>> b = torch.tensor([-1, 2, 1])
>>> b
tensor([-1, 2, 1]) # b 是 1×3 的张量
>>> torch.mul(a, b)
tensor([[-2, 4, 2],
[-3, 6, 3]])
这个例子中,input
和 output
的形状都不是公共形状,因此两个都需要广播,都变成 2×3 的形状,然后再逐个元素相乘。
$$
\begin{gather}
\begin{pmatrix}
2 \
3
\end{pmatrix}
\Rightarrow
\begin{pmatrix}
2 & 2 & 2 \
3 & 3 & 3
\end{pmatrix}
\ , \
\begin{pmatrix}
-1 & 2 & 1
\end{pmatrix}
\Rightarrow
\begin{pmatrix}
-1 & 2 & 1 \
-1 & 2 & 1
\end{pmatrix}
\
\
\begin{pmatrix}
2 & 2 & 2 \
3 & 3 & 3
\end{pmatrix}
×
\begin{pmatrix}
-1 & 2 & 1 \
-1 & 2 & 1
\end{pmatrix}
\begin{pmatrix}
-2 & 4 & 2 \
-3 & 6 & 3
\end{pmatrix}
\end{gather}
$$
由上述例子可以看出,这种乘法是逐个对应元素相乘,因此 input
和 output
的前后顺序并不影响结果,即 torch.mul(a, b) =torch.mul(b, a)
。
torch.multiply()
torch.mul()
的别称。
torch.dot()
函数功能:计算 input
和 output
的点乘,此函数要求 input
和 output
都必须是一维的张量(其 shape 属性中只有一个值)!并且要求两者元素个数相同!
举例如下:
>>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
tensor(7)
>>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1, 1])) # 要求两者元素个数相同
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
RuntimeError: inconsistent tensor size, expected tensor [2] and src [3] to have the same number of elements, but got 2 and 3 elements respectively
torch.mm()
函数功能:实现线性代数中的矩阵乘法(matrix multiplication):(n×m)
× (m×p)
= (n×p)
。
本函数不允许广播!
举例如下:
>>> mat1 = torch.randn(2, 3)
>>> mat2 = torch.randn(3, 2)
>>> torch.mm(mat1, mat2)
tensor([[-1.1846, -1.8327],
[ 0.8820, 0.0312]])
torch.mv()
函数功能:实现矩阵和向量(matrix × vector)的乘法,要求 input
的形状为 n×m
,output
为 torch.Size([m])
的一维 tensor。
举例如下:
>>> mat = torch.tensor([[1, 2, 3], [4, 5, 6]])
>>> mat
tensor([[1, 2, 3],
[4, 5, 6]])
>>> vec = torch.tensor([-1, 1, 2])
>>> vec
tensor([-1, 1, 2])
>>> mat.shape
torch.Size([2, 3])
>>> vec.shape
torch.Size([3])
>>> torch.mv(mat, vec)
tensor([ 7, 13])
注意,此函数要求第二个参数是一维 tensor,也即其 ndim
属性值为 1。这里我们要区分清楚张量的 shape
属性和 ndim
属性,前者表示张量的形状,后者表示张量的维度。(线性代数中二维矩阵的维度 m×n
通常理解为这里的形状)
对于 shape
值为 torch.Size([n])
和 torch.Size(1, n)
的张量,前者的 ndim=1
,后者的 ndim=2
,因此前者是可视为线代中的向量,后者可视为线代中的矩阵。
对于 shape
值为 torch.Size([1, n])
和 torch.Size([n, 1])
的张量,它们同样在 Pytorch
中被视为矩阵。例如:
>>> column = torch.tensor([[1], [2]])
>>> row = torch.tensor([3, 4])
>>> column.shape
torch.Size([2, 1]) # 矩阵
>>> row.shape
torch.Size([2]) # 一维张量
>>> matrix = torch.randn(1, 3)
>>> matrix.shape
torch.Size([1, 3]) # 矩阵
对于张量(以及线代中的向量和矩阵)的理解可看这篇博文。
torch.bmm()
函数功能:实现批量的矩阵乘法。
本函数要求 input
和 output
的 ndim
均为 3,且前者形状为 b×n×m
,后者形状为 b×m×p
。可以理解为 input
中包含 b
个形状为 n×m
的矩阵, output
中包含 b
个形状为 m×p
的矩阵,然后第一个 n×m
的矩阵 × 第一个 m×p
的矩阵得到第一个 n×p
的矩阵,第二个……,第 b
个……因此最终得到 b
个形状为 n×p
的矩阵,即最终结果是一个三维张量,形状为 b×n×p
。
举例如下:
>>> batch_matrix_1 = torch.tensor([ [[1, 2], [3, 4], [5, 6]] , [[-1, -2], [-3, -4], [-5, -6]] ])
>>> batch_matrix_1
tensor([[[ 1, 2],
[ 3, 4],
[ 5, 6]],
[[-1, -2],
[-3, -4],
[-5, -6]]])
>>> batch_matrix_1.shape
torch.Size([2, 3, 2])
>>> batch_matrix_2 = torch.tensor([ [[1, 2], [3, 4]], [[1, 2], [3, 4]] ])
>>> bat
batch_matrix_1 batch_matrix_2
>>> batch_matrix_2
tensor([[[1, 2],
[3, 4]],
[[1, 2],
[3, 4]]])
>>> batch_matrix_2.shape
torch.Size([2, 2, 2])
>>> torch.bmm(batch_matrix_1, batch_matrix_2)
tensor([[[ 7, 10],
[ 15, 22],
[ 23, 34]],
[[ -7, -10],
[-15, -22],
[-23, -34]]])
torch.matmul()
torch.matmul()
可以用于 PyTorch
中绝大多数的乘法,在不同的情形下,它与上述各个乘法函数起着相同的作用,具体请看这篇博文
PyTorch 中的乘法:mul()、multiply()、matmul()、mm()、mv()、dot()的更多相关文章
- PyTorch 中 torch.matmul() 函数的文档详解
官方文档 torch.matmul() 函数几乎可以用于所有矩阵/向量相乘的情况,其乘法规则视参与乘法的两个张量的维度而定. 关于 PyTorch 中的其他乘法函数可以看这篇博文,有助于下面各种乘法的 ...
- numpy 和tensorflow 中的乘法
矩阵乘法:tf.matmul() np.dot() ,@ 逐元素乘法:tf.multiply() np.multiply()
- Pytorch中的自编码(autoencoder)
Pytorch中的自编码(autoencoder) 本文资料来源:https://www.bilibili.com/video/av15997678/?p=25 什么是自编码 先压缩原数据.提取出最有 ...
- [转载]Pytorch中nn.Linear module的理解
[转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...
- PyTorch官方中文文档:PyTorch中文文档
PyTorch中文文档 PyTorch是使用GPU和CPU优化的深度学习张量库. 说明 自动求导机制 CUDA语义 扩展PyTorch 多进程最佳实践 序列化语义 Package参考 torch to ...
- PyTorch中ReLU的inplace
0 - inplace 在pytorch中,nn.ReLU(inplace=True)和nn.LeakyReLU(inplace=True)中存在inplace字段.该参数的inplace=True的 ...
- pytorch中tensorboardX的用法
在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...
- Pytorch中RoI pooling layer的几种实现
Faster-RCNN论文中在RoI-Head网络中,将128个RoI区域对应的feature map进行截取,而后利用RoI pooling层输出7*7大小的feature map.在pytorch ...
- pytorch 中的重要模块化接口nn.Module
torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...
随机推荐
- Spring 官宣发布 Spring Boot 3.0 第一个里程碑 M1,从 Java 8 提升到 Java 17!
Spring官方于2022年1月20日发布Spring Boot 3.0.0-M1版本,预示开启了Spring Boot 3.0的里程碑,相信这是通往下一代Spring框架的激动人心的旅程. 接下来一 ...
- 使用 ES Module 的正确姿势
前面我们在深入理解 ES Module 中详细介绍过 ES Module 的工作原理.目前,ES Module 已经在逐步得到各大浏览器厂商以及 NodeJS 的原生支持.像 vite 等新一代的构建 ...
- go get失败解决办法
go get时由于防火墙的原因,会导致失败.目前可以通过修改GOPROXY的方法解决该问题. 无论是在win下还是linux,macos下,只需要将环境变量GOPROXY设置成https://gopr ...
- Python小练习-购物商城(一部分代码,基于python2.7.5)
新手写作,用来练习与提高python编写.思考能力,有错误的地方请指正,谢谢! 第一次写博客,课题是一位大神的博客,本着练习的目的,就自己重写了一遍,有很多不足的地方,希望借博客记录下自己的成长: ...
- Tomcat-如何在IDEA启动部署web模板
IDEA部署工程到Tomcat上运行 1,建议修改web工程对应的Tomcat运行实例名称 2,将需要部署的web工程添加到Tomcat运行实例中,添加或删除 Application context: ...
- linux系统——Redis集群搭建(主从+哨兵模式)
趁着这几天刚好有点空,就来写一下redis的集群搭建,我跟大家先说明,本文的redis集群因为linux服务器只是阿里云一台服务器,所以集群是redis启动不同端口,但是也能达到集群的要求.其实不同服 ...
- spring 异常处理的方式?
一.使用SimpleMappingExceptionResolver解析器 1.1在mvc中进行 配置. <?xml version="1.0" encoding=" ...
- 从服务之间的调用来看 我们为什么需要Dapr
Dapr 相关的文章我已经写了20多篇了[1] . 当向其他人推荐Dapr 的时候,需要回答的一个问题就是: Dapr 似乎并不是特别令人印象深刻.它提供了一组"构建块",解决了与 ...
- 计算机电子书 2020 CDNDrive 备份(预览版 II)
下载方式 pip install CDNDrive # 或 # pip install git+https://github.com/apachecn/CDNDrive cdrive download ...
- bind方法源码
'use strict'; module.exports = function bind(fn, thisArg) { return function wrap() { var args = new ...