torch.nn.Linear()函数的理解
import torch
x = torch.randn(128, 20) # 输入的维度是(128,20)
m = torch.nn.Linear(20, 30) # 20,30是指维度
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
# ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
ans = torch.mm(x, m.weight.t()) + m.bias
print('ans.shape:\n', ans.shape)
print(torch.equal(ans, output))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
m.weight.shape:
torch.Size([30, 20])
m.bias.shape:
torch.Size([30])
output.shape:
torch.Size([128, 30])
ans.shape:
torch.Size([128, 30])
True
1
2
3
4
5
6
7
8
9
为什么 m.weight.shape = (30,20)?
答:因为线性变换的公式是:
y=xAT+b y=xA^T+b
y=xA
T
+b
先生成一个(30,20)的weight,实际运算中再转置,这样就能和x做矩阵乘法了
---------------------
作者:m0_37586991
来源:CSDN
原文:https://blog.csdn.net/m0_37586991/article/details/87861418
版权声明:本文为博主原创文章,转载请附上博文链接!
torch.nn.Linear()函数的理解的更多相关文章
- [转载]Pytorch中nn.Linear module的理解
[转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...
- 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()
模型训练的三要素:数据处理.损失函数.优化算法 数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...
- 关于torch.nn.Linear的笔记
关于该类: torch.nn.Linear(in_features, out_features, bias=True) 可以对输入数据进行线性变换: $y = x A^T + b$ in_featu ...
- pytorch中文文档-torch.nn常用函数-待添加-明天继续
https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...
- torch.nn.LSTM()函数维度详解
123456789101112lstm=nn.LSTM(input_size, hidden_size, num_la ...
- torch.nn.MSELoss()函数解读
转载自:https://www.cnblogs.com/tingtin/p/13902325.html
- pytorch函数之nn.Linear
class torch.nn.Linear(in_features,out_features,bias = True )[来源] 对传入数据应用线性变换:y = A x+ b 参数: in_featu ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- PyTorch里面的torch.nn.Parameter()
在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里, ...
随机推荐
- RDA 常用API
FAC: app_factory_data_new.h app_guiobj_cul_fm_factorySetting_new.c _APP_Update_Layer() //刷新节点 密码文件: ...
- HDU 5912 Fraction (模拟)
题意:给定一个分式,让你化简. 析:从分母开始模拟分数的算法,最后约分. 代码如下: #pragma comment(linker, "/STACK:1024000000,102400000 ...
- CodeForces 723A The New Year: Meeting Friends (水题)
题意:给定 3 个数,求其中一个数到另外两个数之间的最短距离. 析:很明显选中间那个点了. 代码如下: #pragma comment(linker, "/STACK:1024000000, ...
- ThinkPHP3.2.3学习笔记3---视图
一.说明 每个模块的模板文件是独立的,为了对模板文件更加有效的管理,ThinkPHP对模板文件进行目录划分,默认的模板文件定义规则是:视图目录/[模板主题/]控制器名/操作名+模板后缀 默认的视图目录 ...
- Windows 中属于不同Owner的Workspace 互相无法看见,且无法删除
Windows 中属于不同Owner的Workspace 互相无法看见,且无法删除 而且不能重叠,重叠的话会报错,告诉你这个文件夹已经被其他用户映射, 所以你必须以那个Owner登陆tfs,然后删除 ...
- Ubuntu 18.04 LTS 系统设置打不开了
在更换软件源后,安装了vim 和chrome,chrome很顺利,但是安装vim的时候后就显示有问题了,有的依赖版本不对,嫌版本太高,卸载后再装可以. 安装了python和python2 依旧是有些依 ...
- [转]POJ WA/RE指南
"POJ上头的题都是数学题",也不知道是那个家伙胡诌的--但是POJ的要求就是算法通过了也不让你AC.下面本人就这560题的经验,浅谈一下WA/RE了怎么办. 以下内容是扯淡-- ...
- [Usaco2013 Jan]Island Travels
Description Farmer John has taken the cows to a vacation out on the ocean! The cows are living on N ...
- ORACLE 如何查看存储过程的定义
ORACLE 如何查看存储过程的定义 相关的数据字典 USER_SOURCE 用户的存储过程.函数的源代码字典 DBA_SOURCE 整个系统所有用户的存储过程.函数的源代码字典 ALL_SOUR ...
- Kali linux 2016.2(Rolling)里安装中文输入法
写在前面的话 关于中文输入法,实在是有太多了.当然,你也不可以不安装,(安装了增强工具即可),在windows 里输入中文,复制进去即可. 但是呢,想成为高手,还是要学会安装和使用各版本的中文输入法. ...