pytorch的matmul怎么广播
1)如果是两个1维的,就向量内积;
2)如果两个都是2维的,就矩阵相乘
3)如果第一个是1维,第二个是2维;填充第一个使得能够和第二个参数相乘;如果第一个是2维,第二个是1维,就是矩阵和向量相乘;
例: a = torch.zeros(7,)
b= torch.zeros(7,8)
>>> torch.matmul(a,b)
tensor([0., 0., 0., 0., 0., 0., 0., 0.])
如果a是5,8维,报错。
例子2:
a = torch.zeros(8,)
b= torch.zeros(7,8)
>>> torch.matmul(b,a)
tensor([0., 0., 0., 0., 0., 0., 0.])
如果a是5,7维,报错。
4) 高维情况
i)其中一个1维,另一个N维(N>2): 类似3),即需要靠近的那个维数相同,比如(7,)和(7,8,11),又比如(7,8,11)和(11,)
ii)都高于2维,那么除掉最后两个维度之外(两个数的维数满足矩阵乘法,即,(m,p)和(p,n)),剩下的纬度满足pytorch的广播机制。比如:
# can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist
https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics https://pytorch.org/docs/stable/torch.html?highlight=matmul#torch.matmultorch.matmul(tensor1, tensor2, out=None) → Tensor: # can line up trailing dimensions
>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty( 3,1,1)
# x and y are broadcastable.
# 1st trailing dimension: both have size 1
# 2nd trailing dimension: y has size 1
# 3rd trailing dimension: x size == y size
# 4th trailing dimension: y dimension doesn't exist ###########################################################################需要除掉两个数的后两个维度,然后再看是否满足广播机制torch.matmul(tensor1, tensor2, out=None) → Tensor
Matrix product of two tensors.
The behavior depends on the dimensionality of the tensors as follows:
If both tensors are 1-dimensional, the dot product (scalar) is returned.
If both arguments are 2-dimensional, the matrix-matrix product is returned.
If the first argument is 1-dimensional and the second argument is 2-dimensional, a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.
If the first argument is 2-dimensional and the second argument is 1-dimensional, the matrix-vector product is returned.
If both arguments are at least 1-dimensional and at least one argument is N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after. If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. The non-matrix (i.e. batch) dimensions are broadcasted (and thus must be broadcastable). For example, if
tensor1is a (j \times 1 \times n \times m)(j×1×n×m)tensor andtensor2is a (k \times m \times p)(k×m×p) tensor,outwill be an (j \times k \times n \times p)(j×k×n×p) tensor.
pytorch的matmul怎么广播的更多相关文章
- TensorFlow+TVM优化NMT神经机器翻译
TensorFlow+TVM优化NMT神经机器翻译 背景 神经机器翻译(NMT)是一种自动化的端到端方法,具有克服传统基于短语的翻译系统中的弱点的潜力.本文为全球电子商务部署NMT服务. 目前,将Tr ...
- [深度学习] pytorch学习笔记(1)(数据类型、基础使用、自动求导、矩阵操作、维度变换、广播、拼接拆分、基本运算、范数、argmax、矩阵比较、where、gather)
一.Pytorch安装 安装cuda和cudnn,例如cuda10,cudnn7.5 官网下载torch:https://pytorch.org/ 选择下载相应版本的torch 和torchvisio ...
- PyTorch 中的乘法:mul()、multiply()、matmul()、mm()、mv()、dot()
torch.mul() 函数功能:逐个对 input 和 other 中对应的元素相乘. 本操作支持广播,因此 input 和 other 均可以是张量或者数字. 举例如下: >>> ...
- PyTorch 中 torch.matmul() 函数的文档详解
官方文档 torch.matmul() 函数几乎可以用于所有矩阵/向量相乘的情况,其乘法规则视参与乘法的两个张量的维度而定. 关于 PyTorch 中的其他乘法函数可以看这篇博文,有助于下面各种乘法的 ...
- PyTorch 广播机制
PyTorch 广播机制 定义 PyTorch的tensor参数可以自动扩展其大小.一般的是小一点的会变大,来满足运算需求. 规则 满足一下情况的tensor是可以广播的. 至少有一个维度 两个ten ...
- 『PyTorch』第五弹_深入理解Tensor对象_中下:数学计算以及numpy比较_&_广播原理简介
一.简单数学操作 1.逐元素操作 t.clamp(a,min=2,max=4)近似于tf.clip_by_value(A, min, max),修剪值域. a = t.arange(0,6).view ...
- 【Pytorch】关于torch.matmul和torch.bmm的输出tensor数值不一致问题
发现 对于torch.matmul和torch.bmm,都能实现对于batch的矩阵乘法: a = torch.rand((2,3,10))b = torch.rand((2,2,10))### ma ...
- pytorch & numpy广播法则
广播法则 所有数组向维度最高的数组看齐,若维度不足则在最前面的维度用1补齐 扩展维度后,所有数组在某一维度相同或者长度为1,否则不能计算 当可以计算时,将长度为1的维度扩展为另一数组相应维度的长度 a ...
- 『PyTorch × TensorFlow』第十七弹_ResNet快速实现
『TensorFlow』读书笔记_ResNet_V2 对比之前的复杂版本,这次的torch实现其实简单了不少,不过这和上面的代码实现逻辑过于复杂也有关系. 一.PyTorch实现 # Author : ...
随机推荐
- python部署到服务器(1) 一一 搭建环境
本机环境说明 linux下的CentOS 7, 自带python2.7.5, 使用 python --version 命令查看,因系统需要python2.7.5,因此我们并不卸载,另外安装python ...
- 蓝牙App漏洞系列分析之三CVE-2017-0645
蓝牙App漏洞系列分析之三CVE-2017-0645 0x01 漏洞简介 Android 6月的安全公告,同时还修复了我们发现的一个蓝牙 App 提权中危漏洞,该漏洞允许手机本地无权限的恶意程序构造一 ...
- datePicker 及 timePicker 监听事件 获取用户选择 年月日分秒信息
public class MainActivity extends AppCompatActivity { private TimePicker timePicker; private DatePic ...
- Jumpserver1.4.1安装
第1章 CentOS环境准备 Jumpserver官网: http://docs.jumpserver.org/zh/docs/step_by_step.html 测试推荐硬件 CPU: 64位双核处 ...
- 【树形dp 思维题】HHHOJ#483. NOIP司马懿
要注意利用一些题目的特殊条件吧. 题目大意 有一颗$n$个点带点权$a_i$的树,$q$次询问树上是否存在长度为$l$的路径. $n,q,l\le 10^5,0 \le a_i \le 2$ 题目分析 ...
- .npy文件怎么打开
import numpy as np test = np.load(r'C:\Users\SAM\PycharmProjects\TEAMWORK\Preprocess_3D\muchdata-50- ...
- vue等诸多概念记录
讲的很好,转载记录下,转载自: https://www.cnblogs.com/taowd/p/11808710.html vue学习笔记-遗留问题记录 Node.js是什么?对node.js的理解 ...
- js获取服务器端时间
第一种: $.ajax({ type:"OPTIONS", url:"/", complete:function(x){ var date = x.getRes ...
- wind本地MySQL数据到hive的指定路径,Could not create file
一:使用:kettle:wind本地MySQL数据到hive的指定路径二:问题:没有root写权限网上说的什么少jar包,我这里不存在这种情况,因为我自己是导入jar包的:mysql-connecto ...
- BZOJ 1406: [AHOI2007]密码箱 exgcd+唯一分解定理
推出来了一个解法,但是感觉复杂度十分玄学,没想到秒过~ Code: #include <bits/stdc++.h> #define ll long long #define N 5000 ...