一、torch.nn简介

官网地址:

torch.nn — PyTorch 2.0 documentation

1. torch.nn中的函数简介

  • Containers:神经网络的骨架

  • Convolution Layers:卷积层

  • Pooling layers:池化层

  • Padding Layers:Padding

  • Non-linear Activations:非线性激活

  • Normalization Layers:正则化层

还有其他函数,详情可以看官方文档。以上这些函数构成了神经网络的基本操作。

2. torch.nn中Containers函数的介绍

Containers一共有六个模块:

  • Module:对于所有神经网络提供一个基本的骨架,一般定义一个神经网络用如下代码。其中,Model代表模型的名称,nn.Module就是继承了这个类的模板。然后我们先用__init__初始化,其中super(Model,self).__init__()指的是对父类进行初始化,后面的部分是根据自己构建的神经网络个性化定制的。之后我们使用forword函数对输入数据进行计算,也可以这么理解:对于一个神经网络,首先输入数据-->使用forword函数计算数据-->输出数据,这个过程也叫前向传播
import torch.nn as nn
import torch.nn.functional as F class Model(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 20, 5)
self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x):
x = F.relu(self.conv1(x))
return F.relu(self.conv2(x))
  • Sequential

  • ModuleList

  • ModuleDict

  • ParameterList

  • ParameterDict

二、实操nn.Module

1. 构建一个简单的神经网络

  • 一些小技巧:在写__init__super函数时,pycharm点击下面这个按钮就可以自动补全:

  • 下面构建一个很简单的神经网络,具体作用就是把输入数据+1然后返回,之后调用这个神经网络:

from torch import nn
import torch #构建一个叫Demo的神经网络
class Demo(nn.Module):
def __init__(self):
super().__init__() def forward(self,input):
output=input+1 #对输入神经网络的数据+1,然后返回
return output #调用神经网络
demo=Demo()
x=torch.tensor(1.0) #输入神经网络的数据
output=demo(x)
print(output) #输出神经网络的数据

[Run] tensor(2.)

2. 神经网络运行过程

为了更好地说明上面代码的运行过程,把debug打到第14行的demo=Demo()代码上,并点击Step into My Code

之后一直点击Step into My Code,就可以看到代码的运行过程如下:

  • 在调用demo=Demo()后,首先使用super().__init__()对\(nn.Module\)进行初始化

  • 然后设定输入值x,并使用demo(x)将该值传入到forword函数中

  • forword函数将该值进行加一,并返回output

  • 最后将返回的output输出

深度学习(六)——神经网络的基本骨架:nn.Module的使用的更多相关文章

  1. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1

    3.Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.1 http://blog.csdn.net/sunbow0 ...

  2. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.3

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.3 http://blog.csdn.net/sunbow0 ...

  3. Spark MLlib Deep Learning Convolution Neural Network (深度学习-卷积神经网络)3.2

    3.Spark MLlib Deep Learning Convolution Neural Network(深度学习-卷积神经网络)3.2 http://blog.csdn.net/sunbow0 ...

  4. 针对深度学习(神经网络)的AI框架调研

    针对深度学习(神经网络)的AI框架调研 在我们的AI安全引擎中未来会使用深度学习(神经网络),后续将引入AI芯片,因此重点看了下业界AI芯片厂商和对应芯片的AI框架,包括Intel(MKL CPU). ...

  5. 深度学习 循环神经网络 LSTM 示例

    最近在网上找到了一个使用LSTM 网络解决  世界银行中各国 GDP预测的一个问题,感觉比较实用,毕竟这是找到的唯一一个可以正确运行的程序. #encoding:UTF-8 import pandas ...

  6. 深度学习——卷积神经网络 的经典网络(LeNet-5、AlexNet、ZFNet、VGG-16、GoogLeNet、ResNet)

    一.CNN卷积神经网络的经典网络综述 下面图片参照博客:http://blog.csdn.net/cyh_24/article/details/51440344 二.LeNet-5网络 输入尺寸:32 ...

  7. AI、机器学习、深度学习、神经网络

    1.AI:人工智能(Artificial Intelligence) 2.机器学习:(Machine Learning, ML) 3.深度学习:Deep Learning 人工功能的实现是让机器自己学 ...

  8. 【ARM-Linux开发】【CUDA开发】【深度学习与神经网络】Jetson Tx2安装相关之三

    JetPack(Jetson SDK)是一个按需的一体化软件包,捆绑了NVIDIA®Jetson嵌入式平台的开发人员软件.JetPack 3.0包括对Jetson TX2 , Jetson TX1和J ...

  9. 【深度学习与神经网络】深度学习的下一个热点——GANs将改变世界

    本文作者 Nikolai Yakovenko 毕业于哥伦比亚大学,目前是 Google 的工程师,致力于构建人工智能系统,专注于语言处理.文本分类.解析与生成. 生成式对抗网络-简称GANs-将成为深 ...

  10. 小白学习之pytorch框架(1)-torch.nn.Module+squeeze(unsqueeze)

    我学习pytorch框架不是从框架开始,从代码中看不懂的pytorch代码开始的 可能由于是小白的原因,个人不喜欢一些一下子粘贴老多行代码的博主或者一些弄了一堆概念,导致我更迷惑还增加了畏惧的情绪(个 ...

随机推荐

  1. 0x03.api接口

    API接口安全 HTTP类接口 SOAP----->http+xml REST------>传递资源 RPC类接口(非web):用于远程调用,类似于客户端和服务端.如,登录的时候,进入服务 ...

  2. 使用nacos配置无效,原因:项目中 gateway服务配置的 application的name:@artifactId@ 和nacos上配置的DataID 不一致导致

    遇到一个问题,项目启动后一致无法正常登陆进入后端,登陆时一直报错返回null,排查后发现是自己粗心,项目中 gateway服务配置的 application的name:@artifactId@   和 ...

  3. 28、错误error

    1.是什么? 在实际的项目中,我们希望通过程序的错误信息快速定位问题,但是又不喜欢错误处理:代码就会很冗余又啰嗦.Go语言没有提供类似Java.C#语言中的try...catch异常处理方法,而是通过 ...

  4. AI换脸利器!Roop下载分享

    ​ 前段时间给大家介绍过换脸界最强的Rope,感兴趣的小伙伴可以戳戳手指 传送门:https://blog.csdn.net/S_eashell?spm=1011.2415.3001.5343 今天要 ...

  5. TensorFlow C++ 初始化 Tensor 内存 到GPU 内存

    最近使用TensorFlow C++版本实现神经网络的部署,我通过GPU 处理得到网络的输入值,因此输入值在GPU内存上保存, TF 输入tensor 的调用语句为 Tensor inputTenso ...

  6. 1.7每日总结-vue链mysql4

    新建/server/router.js,用于配置对应路由let express = require('express')let router = express.Router()let user = ...

  7. tmux 增加历史回滚缓冲区 buffer

    tmux 默认回滚 2000 行,如果要查看更多记录(比如编译报错)可以在.tmux.conf文件中增加一行 set -g history-limit 5000 重启 tmux session 生效

  8. 解锁华为云AI如何助力无人车飞驰“新姿势”,大赛冠军有话说

    摘要:在2020年第二届华为云人工智能大赛•无人车挑战杯赛道中,"华中科技大学无人车一队"借助华为云一站式AI开发与管理平台ModelArts及HiLens端云协同AI开发应用平台 ...

  9. GaussDB拿下的安全认证CC EAL4+究竟有多难?

    摘要:近日,经过全球知名独立认证机构SGS Brightsight实验室的安全评估,华为云GaussDB企业级分布式数据库内核获得全球权威信息技术安全性评估标准CC EAL4+级别认证 本文分享自华为 ...

  10. 字节跳动基于ClickHouse优化实践之“高可用”

    更多技术交流.求职机会,欢迎关注字节跳动数据平台微信公众号,回复[1]进入官方交流群 相信大家都对大名鼎鼎的ClickHouse有一定的了解了,它强大的数据分析性能让人印象深刻.但在字节大量生产使用中 ...