官方文档

torch.matmul() 函数几乎可以用于所有矩阵/向量相乘的情况,其乘法规则视参与乘法的两个张量的维度而定。

关于 PyTorch 中的其他乘法函数可以看这篇博文,有助于下面各种乘法的理解。

torch.matmul() 将两个张量相乘划分成了五种情形:一维 × 一维、二维 × 二维、一维 × 二维、二维 × 一维、涉及到三维及三维以上维度的张量的乘法。

以下是五种情形的详细解释:

  1. 如果两个张量都是一维的,即 torch.Size([n]) ,此时返回两个向量的点积。作用与 torch.dot() 相同,同样要求两个一维张量的元素个数相同。

    例如:

    >>> vec1 = torch.tensor([1, 2, 3])
    >>> vec2 = torch.tensor([2, 3, 4])
    >>> torch.matmul(vec1, vec2)
    tensor(20)
    >>> torch.dot(vec1, vec2)
    tensor(20) # 两个一维张量的元素个数要相同!
    >>> vec1 = torch.tensor([1, 2, 3])
    >>> vec2 = torch.tensor([2, 3, 4, 5])
    >>> torch.matmul(vec1, vec2)
    Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    RuntimeError: inconsistent tensor size, expected tensor [3] and src [4] to have the same number of elements, but got 3 and 4 elements respectively
  2. 如果两个参数都是二维张量,那么将返回矩阵乘积。作用与 torch.mm() 相同,同样要求两个张量的形状需要满足矩阵乘法的条件,即(n×m)×(m×p)=(n×p)

    例如:

    >>> arg1 = torch.tensor([[1, 2], [3, 4]])
    >>> arg1
    tensor([[1, 2],
    [3, 4]])
    >>> arg2 = torch.tensor([[-1], [2]])
    >>> arg2
    tensor([[-1],
    [ 2]])
    >>> torch.matmul(arg1, arg2)
    tensor([[3],
    [5]])
    >>> torch.mm(arg1, arg2)
    tensor([[3],
    [5]]) >>> arg2 = torch.tensor([[-1], [2], [1]])
    >>> torch.matmul(arg1, arg2) # 要求满足矩阵乘法的条件
    Traceback (most recent call last):
    File "<stdin>", line 1, in <module>
    RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x2 and 3x1)
  3. 如果第一个参数是一维张量,第二个参数是二维张量,那么在一维张量的前面增加一个维度,然后进行矩阵乘法,矩阵乘法结束后移除添加的维度。文档原文为:“a 1 is prepended to its dimension for the purpose of the matrix multiply. After the matrix multiply, the prepended dimension is removed.”

    例如:

    >>> arg1 = torch.tensor([-1, 2])
    >>> arg2 = torch.tensor([[1, 2], [3, 4]])
    >>> torch.matmul(arg1, arg2)
    tensor([5, 6]) >>> arg1 = torch.unsqueeze(arg1, 0) # 在一维张量前增加一个维度
    >>> arg1.shape
    torch.Size([1, 2])
    >>> ans = torch.mm(arg1, arg2) # 进行矩阵乘法
    >>> ans
    tensor([[5, 6]])
    >>> ans = torch.squeeze(ans, 0) # 移除增加的维度
    >>> ans
    tensor([5, 6])
  4. 如果第一个参数是二维张量(矩阵),第二个参数是一维张量(向量),那么将返回矩阵×向量的积。作用与 torch.mv() 相同。另外要求矩阵的形状和向量的形状满足矩阵乘法的要求。

    例如:

    >>> arg1 = torch.tensor([[1, 2], [3, 4]])
    >>> arg2 = torch.tensor([-1, 2])
    >>> torch.matmul(arg1, arg2)
    tensor([3, 5]) >>> torch.mv(arg1, arg2)
    tensor([3, 5])
  5. 如果两个参数均至少为一维,且其中一个参数的 ndim > 2,那么……(一番处理),然后进行批量矩阵乘法。

    这条规则将所有涉及到三维张量及三维以上的张量(下文称为高维张量)的乘法分为三类:一维张量 × 高维张量、高维张量 × 一维张量、二维及二维以上的张量 × 二维及二维以上的张量。

    1. 如果第一个参数是一维张量,那么在此张量之前增加一个维度。

      文档原文为:“ If the first argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the batched matrix multiply and removed after.”

    2. 如果第二个参数是一维张量,那么在此张量之后增加一个维度。

      文档原文为:“If the second argument is 1-dimensional, a 1 is appended to its dimension for the purpose of the batched matrix multiple and removed after. ”

    3. 由于上述两个规则,所有涉及到一维张量和高维张量的乘法都被转变为二维及二维以上的张量 × 二维及二维以上的张量。

      然后除掉最右边的两个维度,对剩下的维度进行广播。原文为:“The non-matrix dimensions are broadcasted.”

      然后就可以进行批量矩阵乘法。

      For example, if input is a (j × 1 × n × n) tensor and other is a (k × n × n) tensor, out will be a (j × k × n × n) tensor.

    举例如下:

    >>> arg1 = torch.tensor([1, 2, -1, 1])
    >>> arg2 = torch.randint(low=-2, high=3, size=[3, 4, 1])
    >>> torch.matmul(arg1, arg2)
    tensor([[ 5],
    [-1],
    [-1]]) >>> arg2
    tensor([[[ 2],
    [ 2],
    [-1],
    [-2]], [[-2],
    [ 2],
    [ 1],
    [-2]], [[ 0],
    [ 0],
    [-1],
    [-2]]])

    根据第一条规则,先对 arg1 增加维度:

    >>> arg3 = torch.unsqueeze(arg1, 0)
    >>> arg3
    tensor([[ 1, 2, -1, 1]])
    >>> arg3.shape
    torch.Size([1, 4])

    由于 arg2.shape=torch.Size([3, 4, 1]) ,根据广播的规则,arg3 要被广播为 torch.Size([3, 1, 4]) ,也就是下面的 arg4

    >>> arg4 = torch.tensor([ [[ 1,  2, -1,  1]], [[ 1,  2, -1,  1]], [[ 1,  2, -1,  1]] ])
    >>> arg4
    tensor([[[ 1, 2, -1, 1]], [[ 1, 2, -1, 1]], [[ 1, 2, -1, 1]]])
    >>> arg4.shape
    torch.Size([3, 1, 4])

    最后我们使用乘法函数 torch.bmm() 来进行批量矩阵乘法:

    >>> torch.bmm(arg4, arg2)
    tensor([[[ 5]], [[-1]], [[-1]]])

    由于在第一条规则中对一维张量增加了维度,因此矩阵计算结束后要移除这个维度。移除之后和前面使用 torch.matmul() 的结果相同!

PS:在看文档第五条规则时,起先也非常不明白,试了很多次高维和一维的张量乘法总是提示RuntimeError: mat1 and mat2 shapes cannot be multiplied,然后就尝试理解这条规则。因为这条规则很长,分成了三个小情形,并且这三个情形并不是一一独立的,而是前两个情形经过处理之后最后全都可以转变成第三个情形。另一个理解的突破口是 prependedappended 这两个单词,通过它们的前缀可以猜测出:一个是在张量前面增加维度,一个是在张量后面增加维度,然后广播再进行批量矩阵乘法就验证出来了!

PyTorch 中 torch.matmul() 函数的文档详解的更多相关文章

  1. 在MyEclipse中使用javadoc导出API文档详解

    本篇文档介绍如何在MyEclipse中导出javadoc(API)帮助文档,并且使用htmlhelp.exe和jd2chm.exe生成chm文档. 具体步骤如下: 打开MyEclipse,选中想要制作 ...

  2. MYSQL服务器my.cnf配置文档详解

    MYSQL服务器my.cnf配置文档详解 硬件:内存16G [client] port = 3306 socket = /data/3306/mysql.sock [mysql] no-auto-re ...

  3. 【红外DDE算法】数字细节增强算法的缘由与效果(我对FLIR文档详解)

    [红外DDE算法]数字细节增强算法的缘由与效果(我对FLIR文档详解) 1. 为什么红外系统中图像大多是14bit(甚至更高)?一个红外系统的性能经常以其探测的范围来区别,以及其对最小等效温差指标.首 ...

  4. Log4Net(二)之记录日志到文档详解

    原创文章,转载必需注明出处:http://www.ncloud.hk/%E6%8A%80%E6%9C%AF%E5%88%86%E4%BA%AB/log4net-%E4%BA%8C-%E4%B9%8B% ...

  5. Hibernate配置文档详解

    Hibernate配置文档有框架总部署文档hibernate.cfg.xml 和映射类的配置文档 ***.hbm.xml hibernate.cfg.xml(文件位置直接放在src源文件夹即可) (在 ...

  6. 【PDF】java使用Itext生成pdf文档--详解

    [API接口]  一.Itext简介 API地址:javadoc/index.html:如 D:/MyJAR/原JAR包/PDF/itext-5.5.3/itextpdf-5.5.3-javadoc/ ...

  7. elastic search文档详解

    在elastic search中文档(document)类似于关系型数据库里的记录(record),类型(type)类似于表(table),索引(index)类似于库(database). 文档一定有 ...

  8. 前端 HTML文档 详解

    <!DOCTYPE html> <html lang="en"> <head> <meta charset="UTF-8&quo ...

  9. ABBYY FineReader 15扫描和保存文档详解

    通过使用ABBYY FineReader 15 OCR文字识别软件的扫描和保存文档功能,用户可使用扫描仪或数码照相机获得图像文档,然后再转换为各种数字格式文档. 在"新任务窗口"中 ...

随机推荐

  1. 学习AJAX必知必会(1)~Ajax

    一.ajax(Asynchronous JavaScript And XML,即异步的 JS 和 XML) 1.通过 AJAX 可以在浏览器中向服务器发送异步请求实现无刷新获取数据. 2.优势:无刷新 ...

  2. 不难懂------适配移动端flexible

    基于 vue-cli 配置手淘的 lib-flexible + rem,实现移动端自适应 安装 flexible npm install lib-flexible --save 引入 flexible ...

  3. 【Android】安卓四大组件之内容提供者

    [Android]安卓四大组件之内容提供者 1.关于内容提供者 1.1 什么是内容提供者 内容提供者就是contentProvider,作用有如下: 给多个应用提供数据 类似一个接口 可以和多个应用分 ...

  4. AT3527 [ARC082D] Sandglass

    解法一 直接考虑在初始为 \(a\) 的情况下时刻 \(t\) 时 \(A\) 中剩余的沙子是行不通的,不妨反过来考虑在时刻 \(t\) 每个初始值 \(a\) 的答案,令其为 \(f_t(a)\). ...

  5. IDE集成git

    目录 简介 Git安装 IDE集成Git IDE集成Git代码的创建分享上传 代码的下载和普通上传 分子的创建以及合并 代码的回滚 查看历史版本 简介 Git 是一个开源的分布式版本控制软件,用以有效 ...

  6. android 安装gcc环境

    看到了一篇关于Android上利用终端来使用gcc编译C/C++源程序的文章,我感到无比兴奋,所以立刻将我自己的安装过程记下来.那个后记也很有用的. gcc编译源代码需要创建临时文件,而gcc又只能安 ...

  7. Java语言中的访问权限修饰符

    一个Java应用有很多类,但是有些类,并不希望被其他类使用.每个类中都有数据成员和方法成员,但是并不是每个数据和方法,都允许在其他类中调用.如何能做到访问控制呢?就需要使用访问权限修饰符. Java语 ...

  8. Saas系统架构的思考,多租户Saas架构设计分析

    ToB Saas系统最近几年都很火.很多创业公司都在尝试创建企业级别的应用 cRM, HR,销售, Desk Saas系统.很多Saas创业公司也拿了大额风投.毕竟Saas相对传统软件的优势非常明显. ...

  9. Annotation深入研究——@Documented注释使用

    Documented注释的作用及其javadoc文档生成工具的使用 代码放在MyDocumentedtAnnotationDemo.java文件中 package org.yu.demo16.docu ...

  10. 【HDU6687】Rikka with Stable Marriage(Trie树 贪心)

    题目链接 大意 给定\(A,B\)两个数组,让他们进行匹配. 我们称\(A_i\)与\(B_j\)的匹配是稳定的,当且仅当目前所剩元素不存在\(A_x\)或\(B_y\)使得 \(A_i\oplus ...