『PyTorch』矩阵乘法总结
1. 二维矩阵乘法 torch.mm()
torch.mm(mat1, mat2, out=None),其中mat1(\(n\times m\)),mat2(\(m\times d\)),输出out的维度是(\(n\times d\))。
该函数一般只用来计算两个二维矩阵的矩阵乘法,并且不支持broadcast操作。
2. 三维带batch的矩阵乘法 torch.bmm()
由于神经网络训练一般采用mini-batch,经常输入的时三维带batch的矩阵,所以提供torch.bmm(bmat1, bmat2, out=None),其中bmat1(\(b\times n \times m\)),bmat2(\(b\times m \times d\)),输出out的维度是(\(b \times n \times d\))。
该函数的两个输入必须是三维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
3. 多维矩阵乘法 torch.matmul()
torch.matmul(input, other, out=None)支持broadcast操作,使用起来比较复杂。
针对多维数据 matmul()乘法,我们可以认为该matmul()乘法使用使用两个参数的后两个维度来计算,其他的维度都可以认为是batch维度。假设两个输入的维度分别是input(\(1000 \times 500 \times 99 \times 11\)), other(\(500 \times 11 \times 99\))那么我们可以认为torch.matmul(input, other, out=None)乘法首先是进行后两位矩阵乘法得到\((99 \times 11) \times (11 \times 99)\Rightarrow(99 \times 99)\) ,然后分析两个参数的batch size分别是 \(( 1000 \times 500)\) 和 \(500\) , 可以广播成为 \((1000 \times 500)\), 因此最终输出的维度是(\(1000 \times 500 \times 99 \times 99\))。
4. 矩阵逐元素(Element-wise)乘法 torch.mul()
torch.mul(mat1, other, out=None),其中other乘数可以是标量,也可以是任意维度的矩阵,只要满足最终相乘是可以broadcast的即可
5. 两个运算符 @ 和 *
@:矩阵乘法,自动执行适合的矩阵乘法函数*:element-wise乘法
『PyTorch』矩阵乘法总结的更多相关文章
- 『PyTorch』第二弹重置_Tensor对象
『PyTorch』第二弹_张量 Tensor基础操作 简单的初始化 import torch as t Tensor基础操作 # 构建张量空间,不初始化 x = t.Tensor(5,3) x -2. ...
- 『PyTorch』第十二弹_nn.Module和nn.functional
大部分nn中的层class都有nn.function对应,其区别是: nn.Module实现的layer是由class Layer(nn.Module)定义的特殊类,会自动提取可学习参数nn.Para ...
- 『PyTorch』第九弹_前馈网络简化写法
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下 在前面的例子中,基本上都是将每一层的输出直接作为下一层的 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_下
『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上 # Author : Hellcat # Time : 2018/2/11 import torch as t import t ...
- 『PyTorch』第三弹重置_Variable对象
『PyTorch』第三弹_自动求导 torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现 Varibale包含三个属性: data ...
- 『PyTorch』第十弹_循环神经网络
RNN基础: 『cs231n』作业3问题1选讲_通过代码理解RNN&图像标注训练 TensorFlow RNN: 『TensotFlow』基础RNN网络分类问题 『TensotFlow』基础R ...
- 『PyTorch』第五弹_深入理解Tensor对象_下:从内存看Tensor
Tensor存储结构如下, 如图所示,实际上很可能多个信息区对应于同一个存储区,也就是上一节我们说到的,初始化或者普通索引时经常会有这种情况. 一.几种共享内存的情况 view a = t.arang ...
- 『PyTorch』第五弹_深入理解autograd_上:Variable属性方法
在PyTorch中计算图的特点可总结如下: autograd根据用户对variable的操作构建其计算图.对变量的操作抽象为Function. 对于那些不是任何函数(Function)的输出,由用户创 ...
- 『PyTorch』第四弹_通过LeNet初识pytorch神经网络_上
总结一下相关概念: torch.Tensor - 一个近似多维数组的数据结构 autograd.Variable - 改变Tensor并且记录下来操作的历史记录.和Tensor拥有相同的API,以及b ...
随机推荐
- noip33
T1 第一个猎人死的轮数等于在1号猎人之前死的猎人数+1,如果当前这个人没死,那么他死在一号猎人之前的概率为 \(\frac{w_{i}}{w_{1}+w_{i}}\),因为每死一个就会造成1的贡献, ...
- Git进行clone的时候,报错:remote: HTTP Basic: Access denied fatal: Authentication failed for ...
先执行: git config --system --unset credential.helper 原因:用户名或者密码错: 会提示让重新输入用户名和密码,输入正确的用户名和密码即可! 这样以后发现 ...
- Java 数组结构
数组是最常见的一种数据结构,是相同类型的.用一个标识符封装到一起的基本类型数据序列或对象序列.可以用一个统一的数组名和下标来唯一确定数组中的元素.实质上数组是一个简单的线性序列,因此数组访问起来很快. ...
- input text 只能输入数字 js 正则表达式
$("#txt1").keyup(function () { $(this).val($(this).val().replace(/[^0-9.]/g, '')); }).bind ...
- HCNP Routing&Switching之OSPF虚连接
前文我们了解了OSPF的网络类型.帧中继交换机映射以及路由器帧中继映射相关话题,回顾请参考https://www.cnblogs.com/qiuhom-1874/p/15195762.html:今天我 ...
- git推送文件到gitee
注册gitee账号 设置姓名.个人空间地址 点击头像旁边的加号,新建仓库 安装git # 设置姓名和邮箱,姓名是注册gitee时设置的姓名,邮箱是注册gitee的邮箱 git config --glo ...
- Ubuntu下安装Python3(与旧Python2版本共存)
官网下载Python3的源码 进行配置,在源码目录运行如下命令. ./configure --prefix=/usr/local/python3 --enable-shared 进行编译,在源码目录运 ...
- centos7系统上pgsql的一些报错解决方法
1.2021-07-15 # 问题: 登录时服务器拒绝连接 psql -h 192.168.1.112 # 解决方法:修改配置文件 pg_hba.conf ,将该主机加进白名单 vi pg_hba.c ...
- springboot通过AOP和自定义注解实现权限校验
自定义注解 PermissionCheck: package com.mgdd.sys.annotation; import java.lang.annotation.*; /** * @author ...
- uni-app中websocket的使用 断开重连、心跳机制
前言 最近关于H5和APP的开发中使用到了webSocket,由于web/app有时候会出现网络不稳定或者服务端主动断开,这时候导致消息推送不了的情况,需要客户端进行重连.查阅资料后发现了一个心跳机制 ...