pytorch学习笔记(9)--损失函数
1、损失函数的作用:
(1)计算实际输出和目标输出之间的差距;
(2)为我们更新输出提供一定的依据(也就是反向传播)
官网链接:https://pytorch.org/docs/1.8.1/nn.html
2、损失函数的使用
2.1、L1Loss
注:reduction = “sum” 表示求和 / reduction = "mean" 表示求平均值 默认求平均值
代码:
# file : nn_lose.py
# time : 2022/8/2 上午10:31
# function : L1Loss
import torch
from torch.nn import L1Loss
from torch import nn inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32) # reshape()添加维度,原来tensor是二维
inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3)) loss = L1Loss()
result = loss(inputs, targets)
print(result)
上述代码计算了实际输出[1, 2, 3]和目标输出[1, 2, 5]之间的L1Loss,代码输出结果为:
tensor(0.6667)
2.2MSELoss 均方损失函数:可以设置reduction参数来决定具体的计算方法
代码:
# file : nn_lose.py
# time : 2022/8/2 上午10:31
# function :
import torch
from torch.nn import L1Loss
from torch import nn inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32) # reshape
inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3)) loss = L1Loss(reduction="sum")
result = loss(inputs, targets)
print(result) # MSELoss 均方损失函数
loss_mse = nn.MSELoss(reduction="sum")
result_mse = loss_mse(inputs, targets)
print(result_mse)
结果
:
tensor(2.)
tensor(4.)#均方误差损失函数计算结果
2.3 CrossEntropyLoss交叉熵损失函数----没懂
交叉熵损失函数计算方法的细节可以参照这个博文:交叉熵损失函数。(看上去很牛)
代码:
# file : nn_lose.py
# time : 2022/8/2 上午10:31
# function :
import torch
from torch.nn import L1Loss
from torch import nn inputs = torch.tensor([1, 2, 3], dtype=torch.float32)
targets = torch.tensor([1, 2, 5], dtype=torch.float32) # reshape
inputs = torch.reshape(inputs, (1, 1, 1, 3))
targets = torch.reshape(targets, (1, 1, 1, 3)) loss = L1Loss(reduction="sum")
result = loss(inputs, targets)
print(result) # MSELoss 均方损失函数
loss_mse = nn.MSELoss(reduction="sum")
result_mse = loss_mse(inputs, targets)
print(result_mse) # CrossEntropyLoss
x = torch.tensor([0.1, 0.2, 0.3])
y = torch.tensor([1])
x = torch.reshape(x, (1, 3))
loss_cross = nn.CrossEntropyLoss()
result_cross = loss_cross(x, y)
print(result_cross)
结果
:
tensor(1.1019)
用了之前的一个简单神经网络,测试了损失函数及反向传播
# file : nn_loss_network.py
# time : 2022/8/2 下午2:39
# function :
import torch
import torchvision.datasets
from torch import nn
from torch.nn import Conv2d, MaxPool2d, Flatten, Linear, Sequential
from torch.utils.data import DataLoader dataset = torchvision.datasets.CIFAR10("../dataset", train=False, transform=torchvision.transforms.ToTensor(), download=False)
dataloader = DataLoader(dataset, batch_size=1) class Tudui(nn.Module):
def __init__(self):
super(Tudui, self).__init__()
self.model1 = Sequential(
Conv2d(3, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 32, 5, padding=2),
MaxPool2d(2),
Conv2d(32, 64, 5, padding=2),
MaxPool2d(2),
Flatten(),
Linear(1024, 64),
Linear(64, 64)
) def forward(self, x):
x = self.model1(x)
return x loss = nn.CrossEntropyLoss()
tudui = Tudui()
for data in dataloader:
imgs, targets = data
outputs = tudui(imgs)
result_loss = loss(outputs, targets)
result_loss.backward()
print("ok")
pytorch学习笔记(9)--损失函数的更多相关文章
- [PyTorch 学习笔记] 4.2 损失函数
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/loss_function_1.py https:// ...
- 【pytorch】pytorch学习笔记(一)
原文地址:https://pytorch.org/tutorials/beginner/deep_learning_60min_blitz.html 什么是pytorch? pytorch是一个基于p ...
- Pytorch学习笔记(一)——简介
一.Tensor Tensor是Pytorch中重要的数据结构,可以认为是一个高维数组.Tensor可以是一个标量.一维数组(向量).二维数组(矩阵)或者高维数组等.Tensor和numpy的ndar ...
- [PyTorch 学习笔记] 1.3 张量操作与线性回归
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/linear_regression.py 张量的操作 拼 ...
- [PyTorch 学习笔记] 1.1 PyTorch 简介与安装
PyTorch 的诞生 2017 年 1 月,FAIR(Facebook AI Research)发布了 PyTorch.PyTorch 是在 Torch 基础上用 python 语言重新打造的一款深 ...
- [PyTorch 学习笔记] 4.3 优化器
本章代码: https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson4/optimizer_methods.py https: ...
- Pytorch学习笔记(二)---- 神经网络搭建
记录如何用Pytorch搭建LeNet-5,大体步骤包括:网络的搭建->前向传播->定义Loss和Optimizer->训练 # -*- coding: utf-8 -*- # Al ...
- Pytorch学习笔记(一)---- 基础语法
书上内容太多太杂,看完容易忘记,特此记录方便日后查看,所有基础语法以代码形式呈现,代码和注释均来源与书本和案例的整理. # -*- coding: utf-8 -*- # All codes and ...
- 【深度学习】Pytorch 学习笔记
目录 Pytorch Leture 05: Linear Rregression in the Pytorch Way Logistic Regression 逻辑回归 - 二分类 Lecture07 ...
- [PyTorch 学习笔记] 1.4 计算图与动态图机制
本章代码:https://github.com/zhangxiann/PyTorch_Practice/blob/master/lesson1/computational_graph.py 计算图 深 ...
随机推荐
- Unity组件Toggle详解
1.首先先搭建UI(如下图) 2.如果实现单选的功能需要在Image上面挂载ToggleGroup脚本组件 2.2 选中三个Toggle把ToggleGroup拖到如下图位置即可 2.AllowSwi ...
- Ping命令的基本使用
一.Ping命令的基本使用 在网络中 ping 是一个十分强大的 TCP/IP 工具.我们通常会用它来直接 ping ip 地址,来测试网络的连通情况.它的作用主要为: 1.用来检测网络的连通情况和分 ...
- 我的vim配置相关
谨以此文记录下之前的折腾.(后续可能还会折腾什么) 目标 我的目的很简单,就是希望能有一个启动快速的文本编辑器,可以简单的代码着色,vim键位,简单的文本修改,打开大点的文件不发愁,可以简单的form ...
- VirtualBox上使用qemu和busybear搭建RISCV环境
Step 1:安装一些编译riscv需要的库: sudo apt-get install autoconf automake autotools-dev curl libmpc-dev libmpfr ...
- (1028) 权限,chmod、chgrp、chown详解
https://www.cnblogs.com/Berryxiong/p/6193866.html 例1: $ chgrp - R book /opt/local /book 改变/opt/local ...
- 达芬奇18.1.2安装包下载+软件详细破解安装教程(Win&Mac)
DaVinci Resolve v18.1.2是一款在同一个软件工具中,将剪辑.调色.视觉特效.动态图形和音频后期制作融于一身的解决方案!它采用美观新颖的界面设计,易学易用,能让新手用户快速上手操作, ...
- VS2019编译Qt4.8.7
下载4.8.7源码Index of /archive/qt/4.8/4.8.7 复制mkspecs\win32-msvc2015到mkspecs\win32-msvc2019 修改qmake.conf ...
- GCC gcc 和g++
GCC:GNU Compiler Collection(GUN编译器集合),它可以编译C,C++,Java,Fortran,Pascal,Object-C等语言. gcc是GCC中的GUN C Com ...
- inux配置PATH路径
查看PATH:echo $PATH以添加python3为列 修改方法一:export PATH=P A T H : PATH:PATH:HOME/bin:export PATH=P A T H : P ...
- ncnn 加载 bin文件时,出错 报异常 0xC0000094:Integer division by zero。
这次转yolov8.pt 到 onnx 到 ncnn,调用ncnn,加载bin文件时出错报异常 0xC0000094:Integer division by zero. 解决方式: 导出onnx时,加 ...