转自:https://sherlockliao.github.io/2017/07/10/backward/

backward只能被应用在一个标量上,也就是一个一维tensor,或者传入跟变量相关的梯度。

特别注意Variable里面默认的参数requires_grad=False,所以这里我们要重新传入requires_grad=True让它成为一个叶子节点

对其求偏导:

 import torch as t
from torch.autograd import Variable as v # simple gradient
a = v(t.FloatTensor([2, 3]), requires_grad=True)
b = a + 3
c = b * b * 3
out = c.mean()
out.backward()
print('*'*10)
print('=====simple gradient======')
print('input')
print(a.data)
print('compute result is')
print(out.data[0])
print('input gradients are')
print(a.grad.data)

下面研究一下如何能够对非标量的情况下使用backward。backward里传入的参数是每次求导的一个系数。

首先定义好输入m=(x1,x2)=(2,3),然后我们做的操作就是n=,这样我们就定义好了一个向量输出,结果第一项只和x1有关,结果第二项只和x2有关,那么求解这个梯度,

 # backward on non-scalar output
m = v(t.FloatTensor([[2, 3]]), requires_grad=True)
n = v(t.zeros(1, 2))
n[0, 0] = m[0, 0] ** 2
n[0, 1] = m[0, 1] ** 3
n.backward(t.FloatTensor([[1, 1]]))
print('*'*10)
print('=====non scalar output======')
print('input')
print(m.data)
print('input gradients are')
print(m.grad.data)

jacobian矩阵

对其求导:

k.backward(parameters)接受的参数parameters必须要和k的大小一模一样,然后作为k的系数传回去,backward里传入的参数是每次求导的一个系数。

# jacobian
j = t.zeros(2 ,2)
k = v(t.zeros(1, 2))
m.grad.data.zero_()
k[0, 0] = m[0, 0] ** 2 + 3 * m[0 ,1]
k[0, 1] = m[0, 1] ** 2 + 2 * m[0, 0]
# [1, 0] dk0/dm0, dk1/dm0
k.backward(t.FloatTensor([[1, 0]]), retain_variables=True) # 需要两次反向求导
j[:, 0] = m.grad.data
m.grad.data.zero_()
# [0, 1] dk0/dm1, dk1/dm1
k.backward(t.FloatTensor([[0, 1]]))
j[:, 1] = m.grad.data
print('jacobian matrix is')
print(j)

我们要注意backward()里面另外的一个参数retain_variables=True,这个参数默认是False,也就是反向传播之后这个计算图的内存会被释放,这样就没办法进行第二次反向传播了,所以我们需要设置为True,因为这里我们需要进行两次反向传播求得jacobian矩阵。

PyTorch中的backward [转]的更多相关文章

  1. 关于Pytorch中autograd和backward的一些笔记

    参考自<Pytorch autograd,backward详解>: 1 Tensor Pytorch中所有的计算其实都可以回归到Tensor上,所以有必要重新认识一下Tensor. 如果我 ...

  2. pytorch中tensorboardX的用法

    在代码中改好存储Log的路径 命令行中输入 tensorboard --logdir /home/huihua/NewDisk1/PycharmProjects/pytorch-deeplab-xce ...

  3. pytorch 中的重要模块化接口nn.Module

    torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...

  4. 转pytorch中训练深度神经网络模型的关键知识点

    版权声明:本文为博主原创文章,遵循CC 4.0 BY-SA版权协议,转载请附上原文出处链接和本声明. 本文链接:https://blog.csdn.net/weixin_42279044/articl ...

  5. pytorch中调用C进行扩展

    pytorch中调用C进行扩展,使得某些功能在CPU上运行更快: 第一步:编写头文件 /* src/my_lib.h */ int my_lib_add_forward(THFloatTensor * ...

  6. 关于Pytorch中accuracy和loss的计算

    这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚. 给出实例 def train(train_loader, model, criteon, optimizer, epoc ...

  7. 【PyTorch】PyTorch中的梯度累加

    PyTorch中的梯度累加 使用PyTorch实现梯度累加变相扩大batch PyTorch中在反向传播前为什么要手动将梯度清零? - Pascal的回答 - 知乎 https://www.zhihu ...

  8. PyTorch中的C++扩展

    今天要聊聊用 PyTorch 进行 C++ 扩展. 在正式开始前,我们需要了解 PyTorch 如何自定义module.这其中,最常见的就是在 python 中继承torch.nn.Module,用 ...

  9. PyTorch中的Batch Normalization

    Pytorch中的BatchNorm的API主要有: 1 torch.nn.BatchNorm1d(num_features, 2 3 eps=1e-05, 4 5 momentum=0.1, 6 7 ...

随机推荐

  1. 题解-PKUWC2018 Minimax

    Problem loj2537 Solution pkuwc2018最水的一题,要死要活调了一个多小时(1h59min) 我写这题不是因为它有多好,而是为了保持pkuwc2018的队形,与这题类似的有 ...

  2. 高效的多维空间点索引算法 — Geohash 和 Google S2

    原文地址:https://www.jianshu.com/p/7332dcb978b2   引子 每天我们晚上加班回家,可能都会用到滴滴或者共享单车.打开 app 会看到如下的界面:     app ...

  3. 关于PJ 10.27

    题1 : Orchestra 题意: 给你一个 n*m 的矩阵,其中有一些点是被标记过的. 现在让你求标记个数大于 k 个的二维区间个数. n.m .k 最大是 10 . 分析: part 1: 10 ...

  4. mysql alter 效率

    2017年9月15日 10:36:54 星期五 今天遇到一个效率问题记下来: 场景: mysql要更改一下表字段的注释, 因为sql语句问题, 导致更新了整张表.. 错误: ) UNSIGNED ' ...

  5. Solidworks设计电路外形导入AltiumDesigner

    将实际设计好的三维电路图的底板(单个零件模式下)轮廓另存为dwf或者dwg 这时候会出现一个选项框,需要进行一些设置 单位选择mm,这个按照自己的需求选择单位 单位映射选择为1mm,也就是1:1的比例 ...

  6. Codeforces 671D Roads in Yusland [树形DP,线段树合并]

    洛谷 Codeforces 这是一个非正解,被正解暴踩,但它还是过了. 思路 首先很容易想到DP. 设\(dp_{x,i}\)表示\(x\)子树全部被覆盖,而且向上恰好延伸到\(dep=i\)的位置, ...

  7. 随机生成n位随机数(包含大写字母、小写字母、数字)

    package com.java.weiju; import java.security.SecureRandom; import java.util.Date; import java.util.R ...

  8. js——事件冒泡与捕获小例子

    布局代码 #outer{ width: 300px; height: 300px; background: red; } #inner{ width: 200px; height: 200px; ba ...

  9. Oracle 网络监听配置管理

    Oracle 网络配置与管理 详细信息可以参考以下信息: [学习目标] 一.原理解析 二.配置侦听器(LISTENER) 三.配置客户端网络服务名 四.关于注册 五.查询某服务是静态还是动态注册 Or ...

  10. Confluence 6 管理协同编辑 - 关于 Synchrony

    协同编辑能够让项目小组中的协同合作达到下一个高度.这个页面对相关协同编辑中的问题进行了讨论,能够提供给你所有希望了解的内容. 进入 Collaborative editing 页面来获得项目小组是如何 ...