深入Pytorch微分传参
导数
这段代码揭示了多个变量的微分以及如何求解loss为向量的导数
m1 = Variable(torch.ones((3,2)), requires_grad=True)
m2 = Variable(torch.ones((3,2))*2, requires_grad=True)
m3 = Variable(torch.ones((3,2))*4, requires_grad=True)
x1 = m1*m2
x2 = x1 *m3
y = x1 + x2
gradients= torch.ones((3,2))
y.backward(gradients)
print(f"m1 grad:{m1.grad}, \n m2 grad:{m2.grad}, \n m3 grad:{m3.grad}, \n x1 grad:{x1.grad}, \n x2 grad:{x2.grad}, \n y grad:{y.grad}")
深入导数--hook机制
hook机制的详细解释
这段代码解释了导数是如何自动计算保存的,
import torch
from torch.autograd import Variable
def register_hook(self, hook):
r"""Registers a backward hook.
The hook will be called every time a gradient with respect to the
Tensor is computed. The hook should have the following signature::
hook(grad) -> Tensor or None
The hook should not modify its argument, but it can optionally return
a new gradient which will be used in place of :attr:`grad`.
This function returns a handle with a method ``handle.remove()``
that removes the hook from the module.
Example::
>>> v = torch.tensor([0., 0., 0.], requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
>>> v.backward(torch.tensor([1., 2., 3.]))
>>> v.grad
2
4
6
[torch.FloatTensor of size (3,)]
>>> h.remove() # removes the hook
"""
if not self.requires_grad:
raise RuntimeError("cannot register a hook on a tensor that "
"doesn't require gradient")
if self._backward_hooks is None:
self._backward_hooks = OrderedDict()
if self.grad_fn is not None:
self.grad_fn._register_hook_dict(self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
v = Variable(torch.Tensor([2, 2, 2]), requires_grad=True)
h = v.register_hook(lambda grad: grad * grad) # double the gradient
v.backward(torch.Tensor([1, 1, 2]))
#先计算原始梯度,再进hook,获得一个新梯度。
print(v.grad.data)
# print(v.data)
# v.grad.data=torch.Tensor([0, 0, 0]) 梯度不置0就会根据hook自动累加
v.backward(torch.Tensor([1, 1, 1]))
print(v.grad.data)
# print(v.data)
h.remove() # removes the hook
使用with torch.no_grad()
with torch.no_grad():
train_l = loss(net(features, w, b), labels)
print('epoch %d, loss %f' % (epoch + 1, train_l.mean().numpy()))
SGD
这段代码揭示了一个最简单运用梯度下降的模型
import torch
from torch.autograd import Variable
from torch.distributions import normal
NUMBER = 100
# X = normal.Normal(loc = 0, scale = 1).sample((1, NUMBER))
X = torch.ones((1, NUMBER))*NUMBER
X= Variable(X, requires_grad=False)
b = torch.ones(X.shape[0])
b.requires_grad=True
epoch = 200
for i in range(epoch):
loss = torch.sum((b-X) ** 2)
b.grad = Variable(torch.zeros(X.shape[0])) #梯度置0
loss.backward()
b.data = b.data- b.grad * (1/NUMBER)/10
if not i%10:
print(f" {i} b is: {b}, b.grad is: {b.grad}")
深入Pytorch微分传参的更多相关文章
- Oracle 用Drapper进行like模糊传参查询需要在参数值前后带%符合
Oracle 用Drapper进行like模糊传参查询需要在参数值前后带%符合 string sqlstr="select * from tblname where name like ...
- Angular页面传参的四种方法
1. 基于ui-router的页面跳转传参 (1)在Angular的app.js中用ui-route定义路由,比如有两个页面, 一个页面(producers.html)放置了多个producers,点 ...
- 使用java传参调用exe并且获取程序进度和返回结果的一种方法
文章版权由作者李晓晖和博客园共有,若转载请于明显处标明出处:http://www.cnblogs.com/naaoveGIS/ 1.背景 在某个项目中需要考虑使用java后台调用由C#编写的切图程序( ...
- Oracle Sales Cloud:报告和分析(BIEE)小细节2——利用变量和过滤器传参(例如,根据提示展示不同部门的数据)
在上一篇随笔中,我们建立了部门和子部门的双提示,并将部门和子部门做了关联.那么,本篇随笔我们重点介绍利用建好的双提示进行传参. 在操作之前,我们来看一个报告和分析的具体需求: [1] 两个有关联的提示 ...
- js动态绑定click事件时function传参问题
今天碰到了这样一个问题,我在javascript中动态创建了一个button, 然后我想给改button添加click事件,绑定的function想要传入一个变量参数, 一开始我想直接通过函数传参传进 ...
- C#进阶系列——WebApi 接口参数不再困惑:传参详解
前言:还记得刚使用WebApi那会儿,被它的传参机制折腾了好久,查阅了半天资料.如今,使用WebApi也有段时间了,今天就记录下API接口传参的一些方式方法,算是一个笔记,也希望能帮初学者少走弯路.本 ...
- 点击div 跳转并通过URL传参
点击div前要先给div绑定要传的参数: //给panel绑定自定义属性,方便在跳转时传带参数,键/值对排列 panel.attr("user_age",user_age); pa ...
- 纯html页面之间传参
//页面引入//传参方法,可解析url参数 (function($){ $.getUrlParam = function(name) { var reg = new RegExp("(^|& ...
- ★★★Oracle sql 传参特别注意★★★
最近遇到一个非常烦人的问题,用传参的方式执行sql语句结果老是报 Oracle ORA-01722: 无效数字 一直无法找到原因. 表结构大致如下: table test_station ( tblR ...
随机推荐
- SpringBoot中使用Zuul
Zuul提供了服务网关的功能,可以实现负载均衡.反向代理.动态路由.请求转发等功能.Zuul大部分功能是通过过滤器实现的,除了标准的四种过滤器类型,还支持自定义过滤器. 使用@EnableZuulPr ...
- PWA学习笔记(二)
设计与体验 APP Shell: 1.应用从显示内容上可粗略划分为内容部分和外壳部分,App Shell 就是外壳部分,即页面的基本结构 2.它不仅包括用户能看到的页面框架部分,还包括用户看不到的代码 ...
- 《数据挖掘导论》实验课——实验七、数据挖掘之K-means聚类算法
实验七.数据挖掘之K-means聚类算法 一.实验目的 1. 理解K-means聚类算法的基本原理 2. 学会用python实现K-means算法 二.实验工具 1. Anaconda 2. skle ...
- VS2017初学者如何打开右侧的解决方案资源管理器
- 因果推理的春天系列序 - 数据挖掘中的Confounding, Collidar, Mediation Bias
序章嘛咱多唠两句.花了大半个月才反反复复,断断续续读完了图灵奖得主Judea Pearl的The Book of WHY,感觉先读第四章的案例会更容易理解前三章相对抽象的内容.工作中对于归因问题迫切的 ...
- OpenGL 之 Compute Shader(通用计算并行加速)
平常我们使用的Shader有顶点着色器.几何着色器.片段着色器,这几个都是为光栅化图形渲染服务的,OpenGL 4.3之后新出了一个Compute Shader,用于通用计算并行加速,现在对其进行介绍 ...
- [专题总结]初探插头dp
彻彻底底写到自闭的一个专题. 就是大型分类讨论,压行+宏定义很有优势. 常用滚动数组+哈希表+位运算.当然还有轮廓线. Formula 1: 经过所有格子的哈密顿回路数. 每个非障碍点必须有且仅有2个 ...
- 改变JAVA窗体属性的操作方法
在本篇内容里小编给大家详细分析了关于改变JAVA窗体属性的操作方法和步骤,需要的朋友们学习下. 若将JDK版本升级到最新版本,Java窗体就可以简单实现窗体的透明效果,用户可以通过拉动滑块(Slide ...
- opencv---(腐蚀、膨胀、边缘检测、轮廓检索、凸包、多边形拟合)
一.腐蚀(Erode) 取符合模板的点, 用区域最小值代替中心位置值(锚点) 作用: 平滑对象边缘.弱化对象之间的连接. opencv 中相关函数:(erode) // C++ /** shape: ...
- (三十六)c#Winform自定义控件-步骤控件-HZHControls
官网 http://www.hzhcontrols.com 前提 入行已经7,8年了,一直想做一套漂亮点的自定义控件,于是就有了本系列文章. GitHub:https://github.com/kww ...