参考:1. pytorch学习笔记(九):PyTorch结构介绍

2.pytorch学习笔记(七):pytorch hook 和 关于pytorch backward过程的理解

3.Pytorch入门学习(三):Neural Networks

4.forward

神经网络的典型处理如下所示:

1. 定义可学习参数的网络结构(堆叠各层和层的设计);
2. 数据集输入;
3. 对输入进行处理(由定义的网络层进行处理),主要体现在网络的前向传播;
4. 计算loss ,由Loss层计算;
5. 反向传播求梯度;
6. 根据梯度改变参数值,最简单的实现方式(SGD)为:
   weight = weight - learning_rate * gradient
下面是利用PyTorch定义深度网络层(Op)示例:

  1.  
    class FeatureL2Norm(torch.nn.Module):
  2.  
    def __init__(self):
  3.  
    super(FeatureL2Norm, self).__init__()
  4.  
     
  5.  
    def forward(self, feature):
  6.  
    epsilon = 1e-6
  7.  
    # print(feature.size())
  8.  
    # print(torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).size())
  9.  
    norm = torch.pow(torch.sum(torch.pow(feature,2),1)+epsilon,0.5).unsqueeze(1).expand_as(feature)
  10.  
    return torch.div(feature,norm)
  1.  
    class FeatureRegression(nn.Module):
  2.  
    def __init__(self, output_dim=6, use_cuda=True):
  3.  
    super(FeatureRegression, self).__init__()
  4.  
    self.conv = nn.Sequential(
  5.  
    nn.Conv2d(225, 128, kernel_size=7, padding=0),
  6.  
    nn.BatchNorm2d(128),
  7.  
    nn.ReLU(inplace=True),
  8.  
    nn.Conv2d(128, 64, kernel_size=5, padding=0),
  9.  
    nn.BatchNorm2d(64),
  10.  
    nn.ReLU(inplace=True),
  11.  
    )
  12.  
    self.linear = nn.Linear(64 * 5 * 5, output_dim)
  13.  
    if use_cuda:
  14.  
    self.conv.cuda()
  15.  
    self.linear.cuda()
  16.  
     
  17.  
    def forward(self, x):
  18.  
    x = self.conv(x)
  19.  
    x = x.view(x.size(0), -1)
  20.  
    x = self.linear(x)
  21.  
    return x

由上例代码可以看到,不论是在定义网络结构还是定义网络层的操作(Op),均需要定义forward函数,下面看一下PyTorch官网对PyTorch的forward方法的描述:

那么调用forward方法的具体流程是什么样的呢?具体流程是这样的:

以一个Module为例:
1. 调用module的call方法
2. module的call里面调用module的forward方法
3. forward里面如果碰到Module的子类,回到第1步,如果碰到的是Function的子类,继续往下
4. 调用Function的call方法
5. Function的call方法调用了Function的forward方法。
6. Function的forward返回值
7. module的forward返回值
8. 在module的call进行forward_hook操作,然后返回值
上述中“调用module的call方法”是指nn.Module 的__call__方法。定义__call__方法的类可以当作函数调用,具体参考Python的面向对象编程。也就是说,当把定义的网络模型model当作函数调用的时候就自动调用定义的网络模型的forward方法。nn.Module 的__call__方法部分源码如下所示:

  1.  
    def __call__(self, *input, **kwargs):
  2.  
    result = self.forward(*input, **kwargs)
  3.  
    for hook in self._forward_hooks.values():
  4.  
    #将注册的hook拿出来用
  5.  
    hook_result = hook(self, input, result)
  6.  
    ...
  7.  
    return result

可以看到,当执行model(x)的时候,底层自动调用forward方法计算结果。具体示例如下:

  1.  
    class Function:
  2.  
    def __init__(self):
  3.  
    ...
  4.  
    def forward(self, inputs):
  5.  
    ...
  6.  
    return outputs
  7.  
    def backward(self, grad_outs):
  8.  
    ...
  9.  
    return grad_ins
  10.  
    def _backward(self, grad_outs):
  11.  
    hooked_grad_outs = grad_outs
  12.  
    for hook in hook_in_outputs:
  13.  
    hooked_grad_outs = hook(hooked_grad_outs)
  14.  
    grad_ins = self.backward(hooked_grad_outs)
  15.  
    hooked_grad_ins = grad_ins
  16.  
    for hook in hooks_in_module:
  17.  
    hooked_grad_ins = hook(hooked_grad_ins)
  18.  
    return hooked_grad_ins

model = LeNet()
y = model(x)
如上则调用网络模型定义的forward方法。

PyTorch之前向传播函数自动调用forward的更多相关文章

  1. pytorch 调用forward 的具体流程

    forward方法的具体流程: 以一个Module为例:1. 调用module的call方法2. module的call里面调用module的forward方法3. forward里面如果碰到Modu ...

  2. 根据判断PC浏览器类型和手机屏幕像素自动调用不同CSS的代码

    1.媒体查询方法在 css 里面这样写 -------------------- @media screen and (min-width: 320px) and (max-width: 480px) ...

  3. paintEvent(QPaintEvent*)是系统自动调用的

    qt中函数paintEvent(QPaintEvent*)是被系统自动调用. paintEvent(QPaintEvent *)函数是QWidget类中的虚函数,用于ui的绘制,会在多种情况下被其他函 ...

  4. QT5.3无法自动调用incomingConnection函数的问题(4.7没有这个问题)

    最近将qt4.7的一个工程移到5.3,遇到了几个麻烦事,主要是这个incomingConnection监听后无法自动调用的问题,在4.7上是完全没有问题的,到了5.3就不行,网上也查了下,网友们都是放 ...

  5. 如果浏览器自动调用quirks模式打开的话

    (从已经死了一次又一次终于挂掉的百度空间人工抢救出来的,发表日期 2014-03-21) 则肯定你的html的声明,没有写好. 今天遇到几个,前面莫名其妙的多了个空格(在网页上看源码是多空格,复制到n ...

  6. PHP中 对象自动调用的方法:__set()、__get()、__tostring()

    总结: (1)__get($property_name):获取私有属性$name值时,此对象会自动调用该方法,将属性name值传给参数$property_name,通过这个方法的内部 执行,返回我们传 ...

  7. C++构造函数的自动调用(调用一个父类的构造函数,有显性调用最好,否则就默认调用无参数的构造函数)——哲学思想:不调用怎么初始化父类的成员数据和VMT?

    我总是记不住构造函数的特点,关键还是没有领会那个哲学思想:父类的构造函数一方面要初始化它自己的成员数据,另一方面也要建立它自己的VMT呀!心里默念一百遍:一定调用父类构造函数,一定调用父类构造函数,一 ...

  8. Object之魔术函数__toString() 直接输出对象引用时自动调用

    __toString()是快速获取对象的字符串信息的便捷方式 在直接输出对象引用时自动调用的方法. __toString()的作用 当我们调试程序时,需要知道是否得出正确的数据.比如打印一个对象时,看 ...

  9. QT5.3无法自动调用incomingConnection函数的问题

    最近将qt4.7的一个工程移到5.3,遇到了几个麻烦事,主要是这个incomingConnection监听后无法自动调用的问题,在4.7上是完全没有问题的,到了5.3就不行,网上也查了下,网友们都是放 ...

随机推荐

  1. CentOS8/RHEL8--恢复root用户密码及简易加固GRUB

    CentOS8/RHEL8--简易加固GRUB 今天突然想到放在数据中心的虚拟化平台下的Linux服务器,都是采用默认方式安装的,没有设置太多的安全选项,如果有恶意用户重启服务器后,通过GRUB调整启 ...

  2. go 文件操作 io

    package main import ( "fmt" "os" ) func main() { //打开文件 //概念说明: file 的叫法 //1. fi ...

  3. 详细介绍Java中的堆、栈和常量池

    下面主要介绍JAVA中的堆.栈和常量池: 1.寄存器 最快的存储区, 由编译器根据需求进行分配,我们在程序中无法控制. 2. 栈 存放基本类型的变量数据和对象的引用,但对象本身不存放在栈中,而是存放在 ...

  4. oralce默认语言

    默认语言设置可以确定数据库如何支持与区域设置相关的信息,例如: 日和月份的名称及其缩写 A.M..P.M..A.D. 和 B.C. 的等价表示方法的符号 指定 ORDER BY SQL 子句时字符数据 ...

  5. 【软件安装】python安装numpy、scipy

    python2.7开发环境,若为python3.4的环境则下载对应的软件 系统为64为windows环境,显然不同于32的环境,更繁琐,所谓的网友教程也不尽人意. 安装numpy 下载地址:http: ...

  6. Java8中的LocalDateTime工具类

    网上搜索了半天都没有找到Java8的LocalDateTime的工具类,只好自己写了一个,常用功能基本都有.还在用Date的Java同道该换换了. 个人项目地址:https://github.com/ ...

  7. js写的滑动解锁

    css部分 *{ margin:; padding:; box-sizing: border-box; -webkit-touch-callout: none; -webkit-user-select ...

  8. bzoj1688 疾病管理

    Description Alas! A set of D (1 <= D <= 15) diseases (numbered 1..D) is running through the fa ...

  9. 解决 win10 pycurl安装出错 Command "python setup.py egg_info" failed with error code 10 编译安装包 安装万金油

    版权声明:本文为博主原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明.本文链接:https://blog.csdn.net/yexiaohhjk/article/de ...

  10. iOS应用国际化教程

    开发一款伟大的iOS应用程序是件了不起的事情,但是还有比优秀的代码.华丽的设计以及直观化交互更多的事要做.跻身在App Store排行榜前列还需要正合时宜的产品营销.扩大用户群的能力.实用的工具以及尽 ...