pytorch中的nn.CrossEntropyLoss()
nn.CrossEntropyLoss()这个损失函数和我们普通说的交叉熵还是有些区别。
$x$是模型生成的结果,$class$是数据对应的label
$loss(x,class)=-log(\frac{exp(x[class])}{\sum_j exp(x[j])})=-x[class]+log(\sum_j exp(x[j]))$
nn.CrossEntropyLoss()的使用方式参见如下代码
import torch
import torch.nn as nn # 表示模型的输出output(B,C)格式,B是batch,C是类别
output = torch.randn(2, 3, requires_grad = True) #batch_size设置为2,3分类
# 表示数据的标签label(B)格式,B是batch,其中的数值是位于[0,C-1]
label = torch.empty(2, dtype=torch.long).random_(3) # 0 - 2, 任意选取一个分类
print(output)
'''
tensor([[-1.1313, 0.5944, -1.5735],
[ 1.2037, -1.0548, -0.9253]], requires_grad=True)
'''
print(label)#tensor([0, 2])
loss = nn.CrossEntropyLoss()
#先对每个训练样本求损失,而后再求平均损失
print ('loss :', loss(output, label))#loss : tensor(2.1565, grad_fn=<NllLossBackward>)
pytorch中的nn.CrossEntropyLoss()的更多相关文章
- PyTorch 中,nn 与 nn.functional 有什么区别?
作者:infiniteft链接:https://www.zhihu.com/question/66782101/answer/579393790来源:知乎著作权归作者所有.商业转载请联系作者获得授权, ...
- pytorch中torch.nn构建神经网络的不同层的含义
主要是参考这里,写的很好PyTorch 入门实战(四)--利用Torch.nn构建卷积神经网络 卷积层nn.Con2d() 常用参数 in_channels:输入通道数 out_channels:输出 ...
- 交叉熵的数学原理及应用——pytorch中的CrossEntropyLoss()函数
分类问题中,交叉熵函数是比较常用也是比较基础的损失函数,原来就是了解,但一直搞不懂他是怎么来的?为什么交叉熵能够表征真实样本标签和预测概率之间的差值?趁着这次学习把这些概念系统学习了一下. 首先说起交 ...
- PyTorch中ReLU的inplace
0 - inplace 在pytorch中,nn.ReLU(inplace=True)和nn.LeakyReLU(inplace=True)中存在inplace字段.该参数的inplace=True的 ...
- PyTorch 中 weight decay 的设置
先介绍一下 Caffe 和 TensorFlow 中 weight decay 的设置: 在 Caffe 中, SolverParameter.weight_decay 可以作用于所有的可训练参数, ...
- 小白学习之pytorch框架(4)-softmax回归(torch.gather()、torch.argmax()、torch.nn.CrossEntropyLoss())
学习pytorch路程之动手学深度学习-3.4-3.7 置信度.置信区间参考:https://cloud.tencent.com/developer/news/452418 本人感觉还是挺好理解的 交 ...
- pytorch 中的重要模块化接口nn.Module
torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络 nn.Module 是nn中重要的类,包含网络各层的定义,以及forward方法 对于自己 ...
- [转载]Pytorch中nn.Linear module的理解
[转载]Pytorch中nn.Linear module的理解 本文转载并援引全文纯粹是为了构建和分类自己的知识,方便自己未来的查找,没啥其他意思. 这个模块要实现的公式是:y=xAT+*b 来源:h ...
- Pytorch中nn.Dropout2d的作用
Pytorch中nn.Dropout2d的作用 首先,关于Dropout方法,这篇博文有详细的介绍.简单来说, 我们在前向传播的时候,让某个神经元的激活值以一定的概率p停止工作,这样可以使模型泛化性更 ...
随机推荐
- [golang] nats的消息传递模型介绍
目录 nats的消息传递模型 What is NATS 主题式消息(Subject-Based Messaging) 发布订阅(Publish-Subscribe) 请求应答(Request-Repl ...
- Python的面试题
(1)怎么把一个字符串转换成整型? 可以使用int函数 如 int('3') 结果由字符串'3'变为整型3 (2)python内建数据类型有哪些? int .bool. str.list. ru ...
- Python3 类的继承
目录 继承的基本概念 什么是继承 继承有什么用 如何实现继承 初识继承 寻找继承关系 如何寻找继承关系 实例演示 继承背景下的对象属性查找顺序 派生 新式类和经典类 钻石继承 通过继承实现修改json ...
- 【Redis】349- Redis 入门指南
点击上方"前端自习课"关注,学习起来~ 1. 概述 1.1. Redis 简介 Redis 是速度非常快的非关系型(NoSQL)内存键值数据库,可以存储键和五种不同类型的值之间的映 ...
- python数据结构——单向链表
链表 ( Linked List ) 定义:由许多相同数据类型的数据项按照特定顺序排列而成的线性表. 特点:各个数据在计算机中是随机存放且不连续. 优点:数据的增删改查都很方便,当有新的数据加入的时候 ...
- 从多谐振荡器详细解析到555定时器基本电路(控制LED闪烁)
在学期末,笔者参加了学校的电工实习,前六天做都很快,但是今天要做一个关于555多谐振荡器的LED闪烁电路,由于笔者没有提前准备,导致今天就算把电路搭建出来也不懂具体原理,耗费了不少时间,所以我打算专门 ...
- 1篇文章搞清楚8种JVM内存溢出(OOM)的原因和解决方法
前言 撸Java的同学,多多少少会碰到内存溢出(OOM)的场景,但造成OOM的原因却是多种多样. 堆溢出 这种场景最为常见,报错信息: java.lang.OutOfMemoryError: Java ...
- 如何使用charles抓包H5页面内容
安装charles 这里推荐直接去官网下载 https://www.charlesproxy.com/latest-release/download.do 根据自己的电脑选择合适的安装包,我这里选择m ...
- C# DataTable to List<T> based on reflection.
From https://www.cnblogs.com/zjbky/p/9242140.html static class ExtendClass { public static List<T ...
- JS---封装getScroll函数 & 案例:固定导航栏
封装getScroll函数 1. 获取页面向上或者向左卷曲出去的距离的值 2. 浏览器的滚动事件 function getScroll() { return { left: window.pageXO ...