Pytorch 初次使用nn包
计算图和autograd是十分强大的工具,可以定义复杂的操作并自动求导;然而对于大规模的网络,autograd太过于底层。
在构建神经网络时,我们经常考虑将计算安排成层,其中一些具有可学习的参数,它们将在学习过程中进行优化。
TensorFlow里,有类似Keras,TensorFlow-Slim和TFLearn这种封装了底层计算图的高度抽象的接口,这使得构建网络十分方便。
在PyTorch中,包nn完成了同样的功能。nn包中定义一组大致等价于层的模块。一个模块接受输入的tesnor,计算输出的tensor,而且还保存了一些内部状态比如需要学习的tensor的参数等。nn包中也定义了一组损失函数(loss functions),用来训练神经网络。同时nn包中不光有一些激活函数和层操作外,还包含常见的损失函数。
代码如下:
import torch
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
N, D_in, H, D_out = 64, 1000, 100, 10
#随机生成输入和输出
x = torch.randn(N, D_in, device=device)
y = torch.randn(N, D_out, device=device)
# 使用nn包将我们的模型定义为一系列的层。
# nn.Sequential是包含其他模块的模块,并按顺序应用这些模块来产生其输出。
# 每个线性模块使用线性函数从输入计算输出,并保存其内部的权重和偏差张量。
# 在构造模型之后,我们使用.to()方法将其移动到所需的设备。
model = torch.nn.Sequential(
torch.nn.Linear(D_in,H),
torch.nn.ReLU(),
torch.nn.Linear(H, D_out),
).to(device)
'''
nn包中还有常用的损失函数的定义
MSELoss()中参数reducetion 初始为'mean',为均值,我们使用的是'sum'为和
但是在实践中,通过设置reduction='elementwise_mean'来使用均方误差作为损失更为常见
'''
loss_fn = torch.nn.MSELoss(reduction='elementwise_mean')
learning_rate = 1e-4
for t in range(500):
'''
该操作为前向传播,通过向模型中传入x,进而得到输出y
同时该模块有__call__属性可以像调用函数一样调用他们
这样我们输入张量x,得到了输出张量y_pred
'''
y_pred = model(x)
loss = loss_fn(y_pred,y)
print(t,loss.item())
#运算之前清除梯度
model.zero_grad()
'''
反向传播:计算模型的损失值对模型中可训练参数的梯度
每个参数是否可训练取决于require_grad的布尔值
所以此操作可以计算所有可训练参数的梯度
'''
loss.backward()
#使用梯度下降进行更新
#利用for循环取出model中的parameters()
#在对param.data进行操作
with torch.no_grad():
for param in model.parameters():
param.data -= learning_rate * param.grad
Pytorch 初次使用nn包的更多相关文章
- PyTorch 中,nn 与 nn.functional 有什么区别?
作者:infiniteft链接:https://www.zhihu.com/question/66782101/answer/579393790来源:知乎著作权归作者所有.商业转载请联系作者获得授权, ...
- pytorch中torch.nn构建神经网络的不同层的含义
主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...
- [pytorch笔记] torch.nn vs torch.nn.functional; model.eval() vs torch.no_grad(); nn.Sequential() vs nn.moduleList
1. torch.nn与torch.nn.functional之间的区别和联系 https://blog.csdn.net/GZHermit/article/details/78730856 nn和n ...
- pytorch中的nn.CrossEntropyLoss()
nn.CrossEntropyLoss()这个损失函数和我们普通说的交叉熵还是有些区别 x是模型生成的结果,class是对应的label 具体代码可参见如下 import torch import t ...
- Pytorch并行计算:nn.parallel.replicate, scatter, gather, parallel_apply
import torch import torch.nn as nn import ipdb class DataParallelModel(nn.Module): def __init__(self ...
- pytorch函数之nn.Linear
class torch.nn.Linear(in_features,out_features,bias = True )[来源] 对传入数据应用线性变换:y = A x+ b 参数: in_featu ...
- pytorch 损失函数(nn.BCELoss 和 nn.CrossEntropyLoss)(思考多标签分类问题)
一.BCELoss 二分类损失函数 输入维度为(n, ), 输出维度为(n, ) 如果说要预测二分类值为1的概率,则建议用该函数! 输入比如是3维,则每一个应该是在0--1区间内(随意通常配合sigm ...
- 深度学习之PyTorch实战(2)——神经网络模型搭建和参数优化
上一篇博客先搭建了基础环境,并熟悉了基础知识,本节基于此,再进行深一步的学习. 接下来看看如何基于PyTorch深度学习框架用简单快捷的方式搭建出复杂的神经网络模型,同时让模型参数的优化方法趋于高效. ...
- [PyTorch入门]之从示例中学习PyTorch
Learning PyTorch with examples 来自这里. 本教程通过自包含的示例来介绍PyTorch的基本概念. PyTorch的核心是两个主要功能: 可在GPU上运行的,类似于num ...
随机推荐
- C语言:判断t所指字符串中的字母是否由连续递增字母组成。-判断一个输入的任何整数n,是否等于某个连续正整数序列之和。-将一副扑克牌编号为1到54,以某种方式洗牌,这种方式是将这副牌分成两半,然后将他们交叉,并始终保持编号1的牌在最上方。
//判断t所指字符串中的字母是否由连续递增字母组成. #include <stdio.h> #include <string.h> void NONO(); int fun( ...
- Spring 事务管理的API
Spring事务管理有3个API,均为接口. (1)PlatformTransactionManager 平台事务管理器 常用的实现类: DataSourceTransactionManager ...
- replace() 方法用在字符串中用一些字符替换另一些字符实例
后台给返回的格式是这样的 控制台打印出来格式是这样的 现在需要将这个字符串的数据显示在界面上,1-网站:2-APP:3-客户端 for(var i = 0; i < list.length; i ...
- 吴裕雄 Bootstrap 前端框架开发——Bootstrap 网格系统实例:堆叠的水平
<!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title> ...
- 为什么ISR4K、ASR1K等设备的QoS ACL没有显示计数?
思科的ISR4K和ASR1K设备都是IOS XE的架构,它们和传统的IOS架构是不一样的. 以ISR4K为例,和一般的IOS(例如ISR G2)有所区别,他的转发更依赖硬件完成,针对NAT或QoS应用 ...
- Mac 下 vim 常用命令
vim 三种模式:命令模式.插入模式.底线命令模式. 切换模式: 命令模式: 启动 vim 进入命令模式: i 切换到插入模式,以输入字符. x 删除当前光标所在处的字符. : 切换到底线命令 ...
- 高级T-SQL进阶系列 (一)【上篇】:使用 CROSS JOIN 介绍高级T-SQL
[译注:此文为翻译,由于本人水平所限,疏漏在所难免,欢迎探讨指正] 原文连接:传送门 这是一个新进阶系列的第一篇文章,我们将浏览Transact-SQL(T-SQL)的更多高级特性.这个进阶系列将会包 ...
- DHCP与DHCP中继原理与配置!(重点)
一 .DHCP 服务概述 0:dhcp原理: 集中的管理.分配IP地址,使client动态的获得IP地址.Gateway地址.DNS服务器地址等信息,并能够提升地址的使用率.简单来说,DHCP就是一 ...
- mysql 官方读写分离方案
mysql 8.0 集群模式下的自动读写分离方案 问题 多主模式下的组复制,看起来挺好,起始问题和限制很多.而且中断一个复制就无法配置了 多主模式下,innodbcluster 等于是无用的,不需要自 ...
- java中关于&0xFF 的问题
最近遇到一个问题,半天也没想明白,byte temp = 0xA0,为什么System.out.println(temp),打印的值为:-96:而System.out.println(temp& ...