多分类问题

课程来源:PyTorch深度学习实践——河北工业大学

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

Softmax

这一讲介绍使用softmax分类器实现多分类问题。

上一节课计算的是二分类问题,也就是输出的label可以分类为0,1两类。只要计算出\(P(y=1)\)的概率,那么\(P(y=0)=1-P(y=1)\);所以只需要计算一种类型的概率即可,也就是只要一个参数。

而在使用MINIST对手写数字进行分类的时候一共是有10个分类的(数字0-9)。

处理方式:视为10个二分类问题(一个label和其他9个label),计算每一个label的概率。如下图所示,但是问题在于

  • 每一个二分类问题的结果是独立的,不能保证10个结果加起来等于1,且无法解决互相抑制的问题。
  • 每一个结果不能保证大于0。

我们希望输出是有竞争关系的,也就是如果有一项很大那么其他项要相对比较小。为了解决上述问题,提出softmax函数,使用结构如下:

Softmax计算公式如下:

\[P(y=i)=\frac{e^{z_i}}{\sum_{j=0}^{K-1}e^{Z_j}},i\in \{0,...,K-1\}
\]

Softmax函数计算简单示例如下:

接下来考虑多分类问题中的损失函数如何定义:和上述BCE基本一致,同样使用交叉熵作为损失函数,定义式如下:

\[Loss(\hat Y,Y)=-Ylog\hat Y
\]

上面这种计算方式也就是NLLLoss,这种损失函数的结构如下:

而NLLLoss损失函数加上Softmax就是交叉熵损失,对应PyTorch中的nn.CrossEntropyLoss(),也就是最后一层的非线性变化不需要进行,直接交给上述损失函数即可。如下图所示:

在Minist数据集上实现多分类问题

Minsit 数据介绍:每一个手写图片都可以看做是一个28x28的矩阵,如下图所示:

总体构建模型并训练还是如上四步,在最后一步加上测试过程。

注:1.在视觉处理中,灰度图可以看做单通道图像,而彩色图像事实上就是RGB三通道的矩阵,在PyTorch中要把构造成通道数量的C放在第一维的三维向量。

2.神经网络训练中尽量将图像矩阵转换为0-1分布的数据

模型结构简图:

代码如下:

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim # prepare dataset batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # 归一化 train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size) # 定义模型
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(784, 512)
self.l2 = torch.nn.Linear(512, 256)
self.l3 = torch.nn.Linear(256, 128)
self.l4 = torch.nn.Linear(128, 64)
self.l5 = torch.nn.Linear(64, 10) def forward(self, x):
x = x.view(-1, 784)
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
x = F.relu(self.l3(x))
x = F.relu(self.l4(x))
return self.l5(x) # 最后一层不做激活,不进行非线性变换 model = Net() # construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) # training cycle forward, backward, update def train(epoch):
running_loss = 0.0
for batch_idx, data in enumerate(train_loader, 0):
inputs, target = data
optimizer.zero_grad()
# 预测结果
outputs = model(inputs)
# 交叉熵
loss = criterion(outputs, target)
loss.backward()
optimizer.step() running_loss += loss.item()
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
running_loss = 0.0 def test():
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, dim=1) # dim = 1 列是第0个维度,行是第1个维度
total += labels.size(0)
correct += (predicted == labels).sum().item() # 张量之间的比较运算
print('accuracy on test set: %d %% ' % (100*correct/total)) if __name__ == '__main__':
for epoch in range(10):
train(epoch)
test()

作业

Pytorch详解NLLLoss和CrossEntropyLoss 详见如下网址:https://blog.csdn.net/weixin_43593330/article/details/108622747

Otto Group Product Classification Challenge

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim ##str转数值类型
def label2id(labels):
id=[]
target_labels=['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']
for label in labels:
id.append(target_labels.index(label))
return id
class MyDataset(Dataset):
def __init__(self,filepath):
data=pd.read_csv(filepath)
labels=data['target']
self.x_data=torch.from_numpy(np.array(data)[:,1:-1].astype(np.float32))
self.y_data=label2id(labels)
self.len=data.shape[0]
def __getitem__(self,index):
return self.x_data[index],self.y_data[index]
def __len__(self):
return self.len class Module(nn.Module):
def __init__(self):
super(Module,self).__init__()
self.linear1=nn.Linear(93,64)
self.linear2 = nn.Linear(64, 32)
self.linear3 = nn.Linear(32, 16)
self.linear4 = nn.Linear(16, 9)
self.activate=nn.ReLU() def forward(self,x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.activate(self.linear3(x))
x=self.linear4(x)
return x
def train(epoch):
running_loss=0.0
for batch_idx,data in enumerate(train_loader,1):
x,y=data
y_pred=model(x)
loss=critetion(y_pred,y)
running_loss+=loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_list.append(loss.item())
if batch_idx %100 ==0:
print('[%d, %5d] loss = %.3f' % (epoch + 1, batch_idx, running_loss / 100))
running_loss=0.0 if __name__=="__main__":
train_data = MyDataset('train.csv')
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, num_workers=0)
model=Module()
critetion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
loss_list=[]
for epoch in range(30):
train(epoch)
plt.plot(range(len(loss_list)), loss_list)
plt.xlabel('step')
plt.ylabel('loss')
plt.show() def test():
test_data=pd.read_csv('test.csv')
x_test=torch.from_numpy(np.array(test_data)[:,1:].astype(np.float32))
y_pred=model(x_test)
_,pred=torch.max(y_pred,dim=1)
out=pd.get_dummies(pred)#获取one-hot,其实就是0-8
labels=['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']
out.columns=labels
out.insert(0,'id',test_data['id'])
result=pd.DataFrame(out)
result.to_csv('otto-group-product_predictions.csv', index=False)
test()

loss可视化:

结果:

PyTorch深度学习实践——多分类问题的更多相关文章

  1. PyTorch深度学习实践——处理多维特征的输入

    处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...

  2. PyTorch深度学习实践——反向传播

    反向传播 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 目录 反向传播 笔记 作业 笔记 在之前课程中介绍的线性 ...

  3. PyTorch深度学习实践-Overview

    Overview 1.PyTorch简介 ​ PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强 ...

  4. 万字总结Keras深度学习中文文本分类

    摘要:文章将详细讲解Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CNN.TextCNN. 本文分享自华为云社区<Keras深度学习中文 ...

  5. 深度学习实践系列(2)- 搭建notMNIST的深度神经网络

    如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...

  6. 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络

    前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...

  7. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

  8. 医学图像 | 使用深度学习实现乳腺癌分类(附python演练)

    乳腺癌是全球第二常见的女性癌症.2012年,它占所有新癌症病例的12%,占所有女性癌症病例的25%. 当乳腺细胞生长失控时,乳腺癌就开始了.这些细胞通常形成一个肿瘤,通常可以在x光片上直接看到或感觉到 ...

  9. 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型

    MNIST 被喻为深度学习中的Hello World示例,由Yann LeCun等大神组织收集的一个手写数字的数据集,有60000个训练集和10000个验证集,是个非常适合初学者入门的训练集.这个网站 ...

随机推荐

  1. Django之ajax(jquery)封装(包含 将 csrftoken 写入请求头方法)

    由于支持问题,未使用 es6 语法 _ajax.js /** * 发起请求 * @param url 请求地址 * @param data 请求数据 { } json格式 * @param type ...

  2. java中静态代码块详解

    感谢大佬:https://blog.csdn.net/qq_35868412/article/details/89360250 今天在项目中看到这行代码,静态代码块,很久没用静态代码块了,今天来复习一 ...

  3. Java--面向对象设计

    [转载自本科老师上课课件] 问题一: 在一个软件的功能模块中,需要一种图像处理的功能.该图像处理的策略(如何处理)与图像的内容是相关的.如:卫星的运行图片,使用策略A处理方式,如果是卫星内云图片,则需 ...

  4. Aselenium前言

    https://seleniumhq.github.io/docs/index.html https://www.seleniumhq.org/ THE SELENIUM BROWSER AUTOMA ...

  5. Dubbo原理解析(非常透彻)

    一.概述 dubbo是一款经典的rpc框架,用来远程调用服务的. dubbo的作用: 面向接口的远程方法调用 智能容错和负载均衡 服务自动注册和发现. 自定义序列化协议 Dubbo 架构中的核心角色有 ...

  6. Solution -「AGC 016F」Games on DAG

    \(\mathcal{Description}\)   Link.   给定一个含 \(n\) 个点 \(m\) 条边的 DAG,有两枚初始在 1 号点和 2 号点的棋子.两人博弈,轮流移动其中一枚棋 ...

  7. 5.Flink实时项目之业务数据准备

    1. 流程介绍 在上一篇文章中,我们已经把客户端的页面日志,启动日志,曝光日志分别发送到kafka对应的主题中.在本文中,我们将把业务数据也发送到对应的kafka主题中. 通过maxwell采集业务数 ...

  8. suse 12 利用缓存创建本地源供内网服务使用

    文章目录 服务端获取 添加源 刷新源 清除缓存 安装软件 获取rpm包 客户端测试 zypper --help 前言: 其实,咱也不知道为啥写了这篇博客,咱就是想学一学suse,咱也不会,咱也只能学, ...

  9. suse 12 脚本部署docker(二进制文件)

    suse-linux:~ # cat /etc/issue Welcome to SUSE Linux Enterprise Server 12 SP3 (x86_64) - Kernel \r (\ ...

  10. 本地虚拟机在NAT网络连接模式下如何设置才可以访问外网以及使用Xshell远程连接

    本文演示环境: 笔记本电脑系统:windows 7 虚拟机系统:CentOS 7 虚拟化软件:VMware Workstation 12 远程连接工具:Xshell 5 第一步: 打开虚拟网络编辑器 ...