pytorch中多个loss回传的参数影响示例
写了一段代码如下:
import torch
import torch.nn as nn
import torch.nn.functional as F class Test(nn.Module):
def __init__(self):
super(Test, self).__init__()
self.fc1 = nn.Linear(5, 4)
self.fc2 = nn.Linear(4, 3)
self.fc3 = nn.Linear(4, 3) def forward(self, x):
mid = self.fc1(x)
out1 = self.fc2(mid)
out2 = self.fc3(mid)
return out1, out2 x = torch.randn((3, 5))
y = torch.torch.randint(3, (3,), dtype=torch.int64)
model = Test()
model.train()
optim = torch.optim.RMSprop(model.parameters(), lr=0.001) print(model.fc2.weight)
print(model.fc3.weight)
for i in range(5):
out1, out2 = model(x)
loss1 = F.cross_entropy(out1, y)
loss2 = F.cross_entropy(out2, y)
loss = loss1 + loss2
optim.zero_grad()
loss.backward()
optim.step()
print("-------------after-----------")
print(model.fc2.weight)
print(model.fc3.weight)
在loss.backward()处分别更换为loss1.backward()和loss2.backward(),观察fc2和fc3层的参数变化。
得出的结论为:loss2只影响fc3的参数,loss1只影响fc2的参数。
(粗略分析,抛砖引玉)
pytorch中多个loss回传的参数影响示例的更多相关文章
- 关于Pytorch中accuracy和loss的计算
这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚. 给出实例 def train(train_loader, model, criteon, optimizer, epoc ...
- 在ASP.NET MVC中以post方式传递数组参数的示例
最近在工作中用到了在ASP.NET MVC中以post方式传递数组参数的情况,记录下来,以供参考. 一.准备参数对象 在本例中,我会传递两个数组参数:一个字符串数组,一个自定义对象数组.这个自定义对象 ...
- 在ASP.NET MVC中以post方式传递数组参数的示例【转】
最近在工作中用到了在ASP.NET MVC中以post方式传递数组参数的情况,记录下来,以供参考. 一.准备参数对象 在本例中,我会传递两个数组参数:一个字符串数组,一个自定义对象数组.这个自定义对象 ...
- PyTorch中view的用法
相当于numpy中resize()的功能,但是用法可能不太一样. 我的理解是: 把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其 ...
- Pytorch中的自动求导函数backward()所需参数含义
摘要:一个神经网络有N个样本,经过这个网络把N个样本分为M类,那么此时backward参数的维度应该是[N X M] 正常来说backward()函数是要传入参数的,一直没弄明白backward需要传 ...
- ARTS-S pytorch中backward函数的gradient参数作用
导数偏导数的数学定义 参考资料1和2中对导数偏导数的定义都非常明确.导数和偏导数都是函数对自变量而言.从数学定义上讲,求导或者求偏导只有函数对自变量,其余任何情况都是错的.但是很多机器学习的资料和开源 ...
- Pytorch中torch.autograd ---backward函数的使用方法详细解析,具体例子分析
backward函数 官方定义: torch.autograd.backward(tensors, grad_tensors=None, retain_graph=None, create_graph ...
- 【PyTorch】PyTorch中的梯度累加
PyTorch中的梯度累加 使用PyTorch实现梯度累加变相扩大batch PyTorch中在反向传播前为什么要手动将梯度清零? - Pascal的回答 - 知乎 https://www.zhihu ...
- pytorch中tensorboardX的用法
在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...
随机推荐
- zoj3777 Problem Arrangement(状压dp,思路赞)
The 11th Zhejiang Provincial Collegiate Programming Contest is coming! As a problem setter, Edward i ...
- 【2020杭电多校】Total Eclipse 并查集+思维
题目链接:Total Eclipse 题意: t组输入,给你一个由n个点,m条边构成的图,每一个点的权值是ai.你每一次可以选择一批联通的点,然后让他们的权值都减去1.问最后把所有点的权值都变成0需要 ...
- AcWing 247. 亚特兰蒂斯 (线段树,扫描线,离散化)
题意:给你\(n\)个矩形,求矩形并的面积. 题解:我们建立坐标轴,然后可以对矩形的横坐标进行排序,之后可以遍历这些横坐标,这个过程可以想像成是一条线从左往右扫过x坐标轴,假如这条线是第一次扫过矩形的 ...
- Codeforces Round #658 (Div. 2) C2. Prefix Flip (Hard Version) (构造)
题意:给你两个长度为\(n\)的01串\(s\)和\(t\),可以选择\(s\)的前几位,取反然后反转,保证\(s\)总能通过不超过\(2n\)的操作得到\(t\),输出变换总数,和每次变换的位置. ...
- Kubernets二进制安装(10)之部署主控节点部署调度器服务kube-scheduler
Kubernetes Scheduler是一个策略丰富.拓扑感知.工作负载特定的功能,调度器显著影响可用性.性能和容量.调度器需要考虑个人和集体的资源要求.服务质量要求.硬件/软件/政策约束.亲和力和 ...
- Django分页APP_django-pure-pagination
一.App说明 该App用户Django的数据分页功能 二.安装 pip install django-pure-pagination 三.使用方法 (1)settings注册 INSTALLED_A ...
- VScode 配置c++环境
参考 https://code.visualstudio.com/docs/cpp/config-mingw https://zhuanlan.zhihu.com/p/77645306 主要 http ...
- meidi
最近觉得某些公司的选择题也是很基础,非常值得总结回味.今天做了美的的笔试,20道选择题(单选14+6多选).特此记录如下(部分忘了烦请见谅): 1. 是我昨晚刚刚总结的List,Set,Map的区别: ...
- Leetcode(1)-两数之和
给定一个整数数组和一个目标值,找出数组中和为目标值的两个数. 你可以假设每个输入只对应一种答案,且同样的元素不能被重复利用. 示例: 给定 nums = [2, 7, 11, 15], target ...
- 破解编码面试第六版 - JavaScript
破解编码面试第六版 - JavaScript Cracking the Coding Interview: 189 Programming Questions and Solutions 6th Ed ...