pytorch中多个loss回传的参数影响示例
写了一段代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.fc1 = nn.Linear(5, 4)
self.fc2 = nn.Linear(4, 3)
self.fc3 = nn.Linear(4, 3) def forward(self, x):
mid = self.fc1(x)
out1 = self.fc2(mid)
out2 = self.fc3(mid)
return out1, out2 x = torch.randn((3, 5))
y = torch.torch.randint(3, (3,), dtype=torch.int64)
model = Test()
model.train()
optim = torch.optim.RMSprop(model.parameters(), lr=0.001) print(model.fc2.weight)
print(model.fc3.weight)
for i in range(5):
out1, out2 = model(x)
loss1 = F.cross_entropy(out1, y)
loss2 = F.cross_entropy(out2, y)
loss = loss1 + loss2
optim.zero_grad()
loss.backward()
optim.step()
print("-------------after-----------")
print(model.fc2.weight)
print(model.fc3.weight)
在loss.backward()处分别更换为loss1.backward()和loss2.backward(),观察fc2和fc3层的参数变化。
得出的结论为:loss2只影响fc3的参数,loss1只影响fc2的参数。
(粗略分析,抛砖引玉)
pytorch中多个loss回传的参数影响示例的更多相关文章
- 关于Pytorch中accuracy和loss的计算
这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚. 给出实例 def train(train_loader, model, criteon, optimizer, epoc ...
- 在ASP.NET MVC中以post方式传递数组参数的示例
最近在工作中用到了在ASP.NET MVC中以post方式传递数组参数的情况,记录下来,以供参考. 一.准备参数对象 在本例中,我会传递两个数组参数:一个字符串数组,一个自定义对象数组.这个自定义对象 ...
- 在ASP.NET MVC中以post方式传递数组参数的示例【转】
最近在工作中用到了在ASP.NET MVC中以post方式传递数组参数的情况,记录下来,以供参考. 一.准备参数对象 在本例中,我会传递两个数组参数:一个字符串数组,一个自定义对象数组.这个自定义对象 ...
- PyTorch中view的用法
相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...
- Pytorch中的自动求导函数backward()所需参数含义
摘要:一个神经网络有N个样本,经过这个网络把N个样本分为M类,那么此时backward参数的维度应该是[N X M] 正常来说backward()函数是要传入参数的,一直没弄明白backward需要传 ...
- ARTS-S pytorch中backward函数的gradient参数作用
导数偏导数的数学定义 参考资料1和2中对导数偏导数的定义都非常明确.导数和偏导数都是函数对自变量而言.从数学定义上讲,求导或者求偏导只有函数对自变量,其余任何情况都是错的.但是很多机器学习的资料和开源 ...
- Pytorch中torch.autograd ---backward函数的使用方法详细解析,具体例子分析
backward函数 官方定义: torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph ...
- 【PyTorch】PyTorch中的梯度累加
PyTorch中的梯度累加 使用PyTorch实现梯度累加变相扩大batch PyTorch中在反向传播前为什么要手动将梯度清零? - Pascal的回答 - 知乎 https://www.zhihu ...
- pytorch中tensorboardX的用法
在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...
随机推荐
- 2015-2016 ACM-ICPC, NEERC, Northern Subregional Contest (9/12)
$$2015-2016\ ACM-ICPC,\ NEERC,\ Northern\ Subregional\ Contest$$ \(A.Alex\ Origami\ Squares\) 签到 //# ...
- hdu5627 Clarke and MST (并查集)
Time Limit: 2000/1000 MS (Java/Others) Memory Limit: 65536/65536 K (Java/Others) Total Submission ...
- 2019牛客暑期多校训练营(第七场)B Irreducible Polynomial
传送门 题意: 给你一个n次n+1项式(An*X^n+A(n-1)*X^(n-1)...A*X+A0),将系数An都给你,问你这个多项式是不是一个不可约多项式,可约多项式就是类型(x+1)*(x+2) ...
- Codeforces Round #652 (Div. 2)D. TediousLee 推导
题意: Rooted Dead Bush (RDB) of level 1是只有一个点,如下图 当(RDB) of level i变成(RDB) of level i+1的时候,每一个顶点要进行下面的 ...
- Network of Schools POJ - 1236 有向强连通图
//题意://给你n个学校,其中每一个学校都和一些其他学校有交流,但是这些边都是单向的.你至少需要给几个学校//传递消息可以使全部学校都收到消息,第二问你最少添加几条边可以使它变成一个强连通图//题解 ...
- hdu5375 Gray code
Problem Description The reflected binary code, also known as Gray code after Frank Gray, is a binary ...
- Selenium刚玩一会儿,就感受了私人秘书的体验
学习python的过程中,少不了接触第三方库,毕竟作为胶水语言python的强大之处也就是第三方库体量庞大,无疑大大增强了python的战斗力. 有时候想完成网页自动化操作,这时候Selenium进入 ...
- Spring:解决因@Async引起的循环依赖报错
最近项目中使用@Async注解在方法上引起了循环依赖报错: org.springframework.beans.factory.BeanCurrentlyInCreationException: Er ...
- js的变量,作用域,内存
一,基本类型和引用类型的值基本类型的值是按值访问的,引用类型的值是保存在内存中的对象1,动态的属性 只有引用类型的值可以添加属性方法 不能给基本类型添加属性和方法2,复制变量值 复制基本类型的值,两个 ...
- hdu5693D++游戏 区间DP-暴力递归
主要的收获是..如何优化你递推式里面不必要的决策 之前的代码 这个代码在HDU超时了,这就对了..这个复杂度爆炸.. 但是这个思路非常地耿直..那就是只需要暴力枚举删两个和删三个的情况,于是就非常耿直 ...