import torch

x = torch.randn(128, 20) # 输入的维度是(128,20)
m = torch.nn.Linear(20, 30) # 20,30是指维度
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)

# ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
ans = torch.mm(x, m.weight.t()) + m.bias
print('ans.shape:\n', ans.shape)

print(torch.equal(ans, output))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
m.weight.shape:
torch.Size([30, 20])
m.bias.shape:
torch.Size([30])
output.shape:
torch.Size([128, 30])
ans.shape:
torch.Size([128, 30])
True
1
2
3
4
5
6
7
8
9
为什么 m.weight.shape = (30,20)?

答:因为线性变换的公式是:

y=xAT+b y=xA^T+b
y=xA
T
+b

先生成一个(30,20)的weight,实际运算中再转置,这样就能和x做矩阵乘法了
---------------------
作者:m0_37586991
来源:CSDN
原文:https://blog.csdn.net/m0_37586991/article/details/87861418
版权声明:本文为博主原创文章,转载请附上博文链接!

torch.nn.Linear()函数的理解的更多相关文章

  1. [转载]Pytorch中nn.Linear module的理解

    [转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...

  2. 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()

    模型训练的三要素:数据处理.损失函数.优化算法    数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...

  3. 关于torch.nn.Linear的笔记

    关于该类: torch.nn.Linear(in_features, out_features, bias=True) 可以对输入数据进行线性变换: $y  = x A^T + b$ in_featu ...

  4. pytorch中文文档-torch.nn常用函数-待添加-明天继续

    https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...

  5. torch.nn.LSTM()函数维度详解

    123456789101112lstm=nn.LSTM(input_size,                     hidden_size,                      num_la ...

  6. torch.nn.MSELoss()函数解读

    转载自:https://www.cnblogs.com/tingtin/p/13902325.html

  7. pytorch函数之nn.Linear

    class torch.nn.Linear(in_features,out_features,bias = True )[来源] 对传入数据应用线性变换:y = A x+ b 参数: in_featu ...

  8. PyTorch官方中文文档:torch.nn

    torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...

  9. PyTorch里面的torch.nn.Parameter()

    在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里, ...

随机推荐

  1. 从0开始学习Hadoop(2)安装JDK

    参考文档: 安装包方式安装:http://www.cnblogs.com/wuyudong/p/ubuntu-jdk8.html PPA方式安装:推荐  http://www.cnblogs.com/ ...

  2. Application 效能分析有妙招 — 使用 perf 走天下(转载)

    转载:http://tech.mozilla.com.tw/posts/1803/application-%E6%95%88%E8%83%BD%E5%88%86%E6%9E%90-%E4%BD%BF% ...

  3. zoj3955:Saddle Point(想法题)

    传送门 题意 给出n*m的矩阵,询问所有子矩阵中鞍点的个数 鞍点定义:在行唯一最小,在列唯一最大 分析 我们遍历每个点,计算该点对于答案的贡献即可. 每个点的贡献为\((2^{numa[i][j]}) ...

  4. CCF2016.4 - B题

    思路:创建两个bool数组来模拟下落过程,一个存放面板状态,一个存放下落的格子.检测格子和面板对应位置是否同时为True,如果是则有冲突,不能继续下落,否则增加行号.为了统一处理,我们把面板最下面加一 ...

  5. 比特币搬砖对冲策略Python源码

    策略复制地址:https://www.fmz.com/strategy/21023 策略原理 比特币搬砖策略是入门程序化交易的基础策略.原理简单,是新手尝试程序化的好选择,在其黄金时期,比特币搬砖也带 ...

  6. Intellij IDEA 快捷键整理(史上最全)

    [常规] Ctrl+Shift + Enter,语句完成 “!”,否定完成,输入表达式时按 “!”键 Ctrl+E,最近的文件 Ctrl+Shift+E,最近更改的文件 Shift+Click,可以关 ...

  7. iOS 计算字符串显示宽高度

    ObjC(Category of NSString): - (CGSize)getSizeWithFont:(UIFont *)font constrainedToSize:(CGSize)size{ ...

  8. 关于如何读取XML文件的一个简单方法

    在平时开发系统功能的时候,我们经常会碰到一些需求需要经常性的发生变化,比如 系统版本.更新日志 等等.这个时候用一个XML文件来替代数据库,就会变的简便很多. 前段时候我也正好需要改个需求,是关于客户 ...

  9. ssm(Spring、Springmvc、Mybatis)实战之淘淘商城-第三天(非原创)

    文章大纲 一.课程介绍二.简单功能实现三.图片上传功能实战四.项目源码与资料下载五.参考文章   一.课程介绍 一共14天课程(1)第一天:电商行业的背景.淘淘商城的介绍.搭建项目工程.Svn的使用. ...

  10. vue 数组和对象的双向绑定不响应问题

    对象和数组的数据类型是对象,对象是对象这个是毫无疑问的.数组可以把索引当成键名,把索引对应的元素当成该键名的键值. vue对象有些操作不能双向绑定的原因是vue未改变原对象,以及未给新增属性增加set ...