torch.nn.Linear()函数的理解
import torch
x = torch.randn(128, 20) # 输入的维度是(128,20)
m = torch.nn.Linear(20, 30) # 20,30是指维度
output = m(x)
print('m.weight.shape:\n ', m.weight.shape)
print('m.bias.shape:\n', m.bias.shape)
print('output.shape:\n', output.shape)
# ans = torch.mm(input,torch.t(m.weight))+m.bias 等价于下面的
ans = torch.mm(x, m.weight.t()) + m.bias
print('ans.shape:\n', ans.shape)
print(torch.equal(ans, output))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
m.weight.shape:
torch.Size([30, 20])
m.bias.shape:
torch.Size([30])
output.shape:
torch.Size([128, 30])
ans.shape:
torch.Size([128, 30])
True
1
2
3
4
5
6
7
8
9
为什么 m.weight.shape = (30,20)?
答:因为线性变换的公式是:
y=xAT+b y=xA^T+b
y=xA
T
+b
先生成一个(30,20)的weight,实际运算中再转置,这样就能和x做矩阵乘法了
---------------------
作者:m0_37586991
来源:CSDN
原文:https://blog.csdn.net/m0_37586991/article/details/87861418
版权声明:本文为博主原创文章,转载请附上博文链接!
torch.nn.Linear()函数的理解的更多相关文章
- [转载]Pytorch中nn.Linear module的理解
[转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...
- 小白学习之pytorch框架(3)-模型训练三要素+torch.nn.Linear()
模型训练的三要素:数据处理.损失函数.优化算法 数据处理(模块torch.utils.data) 从线性回归的的简洁实现-初始化模型参数(模块torch.nn.init)开始 from torc ...
- 关于torch.nn.Linear的笔记
关于该类: torch.nn.Linear(in_features, out_features, bias=True) 可以对输入数据进行线性变换: $y = x A^T + b$ in_featu ...
- pytorch中文文档-torch.nn常用函数-待添加-明天继续
https://pytorch.org/docs/stable/nn.html 1)卷积层 class torch.nn.Conv2d(in_channels, out_channels, kerne ...
- torch.nn.LSTM()函数维度详解
123456789101112lstm=nn.LSTM(input_size, hidden_size, num_la ...
- torch.nn.MSELoss()函数解读
转载自:https://www.cnblogs.com/tingtin/p/13902325.html
- pytorch函数之nn.Linear
class torch.nn.Linear(in_features,out_features,bias = True )[来源] 对传入数据应用线性变换:y = A x+ b 参数: in_featu ...
- PyTorch官方中文文档:torch.nn
torch.nn Parameters class torch.nn.Parameter() 艾伯特(http://www.aibbt.com/)国内第一家人工智能门户,微信公众号:aibbtcom ...
- PyTorch里面的torch.nn.Parameter()
在刷官方Tutorial的时候发现了一个用法self.v = torch.nn.Parameter(torch.FloatTensor(hidden_size)),看了官方教程里面的解释也是云里雾里, ...
随机推荐
- hdu5521 ( Dijkstra )
题目描述: 给定一张图,问从1和n相遇的最短时间. 这道题的输入比较特殊,不能直接存,所以怎么输入,怎么存取,只要可以访问到一个节点的相邻节点就可以,由于spfa算法的时间复杂度为m*n,而Dijks ...
- EF 连接MySql
使用EntityFramework6连接MySql数据库(db first方式) http://www.cnblogs.com/24la/archive/2014/04/03/ef6-mysql.ht ...
- join示例分析
join示例分析 public class TestJoin { public static void main(String[] args) throws InterruptedException ...
- 【WIP】Ruby JSON
创建: 2018/03/22 以后有空补上 注: JSON.generate 参数只能是Obejct或者Array, 不可以是Hash https://docs.ruby-lang.org/ja/la ...
- bzoj 4184: shallot【线性基+时间线段树】
学到了线段树新姿势! 先离线读入,根据时间建一棵线段树,每个节点上开一个vector存这个区间内存在的数(使用map来记录每个数出现的一段时间),然后在线段树上dfs,到叶子节点就计算答案. 注意!! ...
- Code:Blocks 中文乱码问题原因分析和解决方法
下面说说修改的地方. 1.修改源文件保存编码在:settings->Editor->gernal settings 看到右边的Encoding group Box了吗?如下图所示: Use ...
- Poj 2516 Minimum Cost (最小花费最大流)
题目链接: Poj 2516 Minimum Cost 题目描述: 有n个商店,m个仓储,每个商店和仓库都有k种货物.嘛!现在n个商店要开始向m个仓库发出订单了,订单信息为当前商店对每种货物的需求 ...
- HDU 5558 后缀数组
思路: 这是一个错误的思路, 因为数据水才过= = 首先求出来后缀数组 把rank插到set里 每回差i两边离i近的rank值,更新 如果LCP相同,暴力左(右)继续更新sa的最小值 //By Sir ...
- Java中的流(1)流简介
简介 1.在java中stream代表一种数据流(源),java.io的底层数据元.(比作成水管)2.InputStream 比作进水管,水从里面流向你,你要接收,read3.OutputStream ...
- java getDocumentBase() 得到的文件夹路径
参考一个百度知道上的回答 举例说来,假设你的项目文件是xx,而这个xx文件夹是在D盘下的yy文件夹里,即项目文件的完整路径D:\yy\xx,则编译运行文件后,在xx文件夹里会产生名为build的文件夹 ...