多分类问题

课程来源:PyTorch深度学习实践——河北工业大学

《PyTorch深度学习实践》完结合集_哔哩哔哩_bilibili

Softmax

这一讲介绍使用softmax分类器实现多分类问题。

上一节课计算的是二分类问题,也就是输出的label可以分类为0,1两类。只要计算出\(P(y=1)\)的概率,那么\(P(y=0)=1-P(y=1)\);所以只需要计算一种类型的概率即可,也就是只要一个参数。

而在使用MINIST对手写数字进行分类的时候一共是有10个分类的(数字0-9)。

处理方式:视为10个二分类问题(一个label和其他9个label),计算每一个label的概率。如下图所示,但是问题在于

  • 每一个二分类问题的结果是独立的,不能保证10个结果加起来等于1,且无法解决互相抑制的问题。
  • 每一个结果不能保证大于0。

我们希望输出是有竞争关系的,也就是如果有一项很大那么其他项要相对比较小。为了解决上述问题,提出softmax函数,使用结构如下:

Softmax计算公式如下:

\[P(y=i)=\frac{e^{z_i}}{\sum_{j=0}^{K-1}e^{Z_j}},i\in \{0,...,K-1\}
\]

Softmax函数计算简单示例如下:

接下来考虑多分类问题中的损失函数如何定义:和上述BCE基本一致,同样使用交叉熵作为损失函数,定义式如下:

\[Loss(\hat Y,Y)=-Ylog\hat Y
\]

上面这种计算方式也就是NLLLoss,这种损失函数的结构如下:

而NLLLoss损失函数加上Softmax就是交叉熵损失,对应PyTorch中的nn.CrossEntropyLoss(),也就是最后一层的非线性变化不需要进行,直接交给上述损失函数即可。如下图所示:

在Minist数据集上实现多分类问题

Minsit 数据介绍:每一个手写图片都可以看做是一个28x28的矩阵,如下图所示:

总体构建模型并训练还是如上四步,在最后一步加上测试过程。

注:1.在视觉处理中,灰度图可以看做单通道图像,而彩色图像事实上就是RGB三通道的矩阵,在PyTorch中要把构造成通道数量的C放在第一维的三维向量。

2.神经网络训练中尽量将图像矩阵转换为0-1分布的数据

模型结构简图:

代码如下:

import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn.functional as F
import torch.optim as optim # prepare dataset batch_size = 64
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) # 归一化 train_dataset = datasets.MNIST(root='../dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataset = datasets.MNIST(root='../dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size) # 定义模型
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = torch.nn.Linear(784, 512)
self.l2 = torch.nn.Linear(512, 256)
self.l3 = torch.nn.Linear(256, 128)
self.l4 = torch.nn.Linear(128, 64)
self.l5 = torch.nn.Linear(64, 10) def forward(self, x):
x = x.view(-1, 784)
x = F.relu(self.l1(x))
x = F.relu(self.l2(x))
x = F.relu(self.l3(x))
x = F.relu(self.l4(x))
return self.l5(x) # 最后一层不做激活,不进行非线性变换 model = Net() # construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5) # training cycle forward, backward, update def train(epoch):
running_loss = 0.0
for batch_idx, data in enumerate(train_loader, 0):
inputs, target = data
optimizer.zero_grad()
# 预测结果
outputs = model(inputs)
# 交叉熵
loss = criterion(outputs, target)
loss.backward()
optimizer.step() running_loss += loss.item()
if batch_idx % 300 == 299:
print('[%d, %5d] loss: %.3f' % (epoch+1, batch_idx+1, running_loss/300))
running_loss = 0.0 def test():
correct = 0
total = 0
with torch.no_grad():
for data in test_loader:
images, labels = data
outputs = model(images)
_, predicted = torch.max(outputs.data, dim=1) # dim = 1 列是第0个维度,行是第1个维度
total += labels.size(0)
correct += (predicted == labels).sum().item() # 张量之间的比较运算
print('accuracy on test set: %d %% ' % (100*correct/total)) if __name__ == '__main__':
for epoch in range(10):
train(epoch)
test()

作业

Pytorch详解NLLLoss和CrossEntropyLoss 详见如下网址:https://blog.csdn.net/weixin_43593330/article/details/108622747

Otto Group Product Classification Challenge

import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import torch.optim as optim ##str转数值类型
def label2id(labels):
id=[]
target_labels=['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']
for label in labels:
id.append(target_labels.index(label))
return id
class MyDataset(Dataset):
def __init__(self,filepath):
data=pd.read_csv(filepath)
labels=data['target']
self.x_data=torch.from_numpy(np.array(data)[:,1:-1].astype(np.float32))
self.y_data=label2id(labels)
self.len=data.shape[0]
def __getitem__(self,index):
return self.x_data[index],self.y_data[index]
def __len__(self):
return self.len class Module(nn.Module):
def __init__(self):
super(Module,self).__init__()
self.linear1=nn.Linear(93,64)
self.linear2 = nn.Linear(64, 32)
self.linear3 = nn.Linear(32, 16)
self.linear4 = nn.Linear(16, 9)
self.activate=nn.ReLU() def forward(self,x):
x = self.activate(self.linear1(x))
x = self.activate(self.linear2(x))
x = self.activate(self.linear3(x))
x=self.linear4(x)
return x
def train(epoch):
running_loss=0.0
for batch_idx,data in enumerate(train_loader,1):
x,y=data
y_pred=model(x)
loss=critetion(y_pred,y)
running_loss+=loss.item()
optimizer.zero_grad()
loss.backward()
optimizer.step()
loss_list.append(loss.item())
if batch_idx %100 ==0:
print('[%d, %5d] loss = %.3f' % (epoch + 1, batch_idx, running_loss / 100))
running_loss=0.0 if __name__=="__main__":
train_data = MyDataset('train.csv')
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True, num_workers=0)
model=Module()
critetion=nn.CrossEntropyLoss()
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.5)
loss_list=[]
for epoch in range(30):
train(epoch)
plt.plot(range(len(loss_list)), loss_list)
plt.xlabel('step')
plt.ylabel('loss')
plt.show() def test():
test_data=pd.read_csv('test.csv')
x_test=torch.from_numpy(np.array(test_data)[:,1:].astype(np.float32))
y_pred=model(x_test)
_,pred=torch.max(y_pred,dim=1)
out=pd.get_dummies(pred)#获取one-hot,其实就是0-8
labels=['Class_1', 'Class_2', 'Class_3', 'Class_4', 'Class_5', 'Class_6', 'Class_7', 'Class_8', 'Class_9']
out.columns=labels
out.insert(0,'id',test_data['id'])
result=pd.DataFrame(out)
result.to_csv('otto-group-product_predictions.csv', index=False)
test()

loss可视化:

结果:

PyTorch深度学习实践——多分类问题的更多相关文章

  1. PyTorch深度学习实践——处理多维特征的输入

    处理多维特征的输入 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 这一讲介绍输入为多维数据时的分类. 一个数据集 ...

  2. PyTorch深度学习实践——反向传播

    反向传播 课程来源:PyTorch深度学习实践--河北工业大学 <PyTorch深度学习实践>完结合集_哔哩哔哩_bilibili 目录 反向传播 笔记 作业 笔记 在之前课程中介绍的线性 ...

  3. PyTorch深度学习实践-Overview

    Overview 1.PyTorch简介 ​ PyTorch是一个基于Torch的Python开源机器学习库,用于自然语言处理等应用程序.它主要由Facebookd的人工智能小组开发,不仅能够 实现强 ...

  4. 万字总结Keras深度学习中文文本分类

    摘要:文章将详细讲解Keras实现经典的深度学习文本分类算法,包括LSTM.BiLSTM.BiLSTM+Attention和CNN.TextCNN. 本文分享自华为云社区<Keras深度学习中文 ...

  5. 深度学习实践系列(2)- 搭建notMNIST的深度神经网络

    如果你希望系统性的了解神经网络,请参考零基础入门深度学习系列,下面我会粗略的介绍一下本文中实现神经网络需要了解的知识. 什么是深度神经网络? 神经网络包含三层:输入层(X).隐藏层和输出层:f(x) ...

  6. 深度学习实践系列(3)- 使用Keras搭建notMNIST的神经网络

    前期回顾: 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型 深度学习实践系列(2)- 搭建notMNIST的深度神经网络 在第二篇系列中,我们使用了TensorFlow搭建了第一个深度 ...

  7. 对比学习:《深度学习之Pytorch》《PyTorch深度学习实战》+代码

    PyTorch是一个基于Python的深度学习平台,该平台简单易用上手快,从计算机视觉.自然语言处理再到强化学习,PyTorch的功能强大,支持PyTorch的工具包有用于自然语言处理的Allen N ...

  8. 医学图像 | 使用深度学习实现乳腺癌分类(附python演练)

    乳腺癌是全球第二常见的女性癌症.2012年,它占所有新癌症病例的12%,占所有女性癌症病例的25%. 当乳腺细胞生长失控时,乳腺癌就开始了.这些细胞通常形成一个肿瘤,通常可以在x光片上直接看到或感觉到 ...

  9. 深度学习实践系列(1)- 从零搭建notMNIST逻辑回归模型

    MNIST 被喻为深度学习中的Hello World示例,由Yann LeCun等大神组织收集的一个手写数字的数据集,有60000个训练集和10000个验证集,是个非常适合初学者入门的训练集.这个网站 ...

随机推荐

  1. .Net Api 之如何使用Elasticsearch存储文档

    .Net Api 之如何使用Elasticsearch存储文档 什么是Elasticsearch? Elasticsearch 是一个分布式.高扩展.高实时的搜索与数据分析引擎.它能很方便的使大量数据 ...

  2. 边带权并查集 学习笔记 & 洛谷P1196 [NOI2002] 银河英雄传说 题解

    花了2h总算把边带权并查集整明白了qaq 1.边带权并查集的用途 众所周知,并查集擅长维护与可传递关系有关的信息.然而我们有时会发现并查集所维护的信息不够用,这时"边带权并查集"就 ...

  3. tcp|ip nagle算法

    在TCP传输数据流中,存在两种类型的TCP报文段,一种包含成块数据(通常是满长度的,携带一个报文段最多容纳的字节数),另一种则包含交互数据(通常只有携带几个字节数据). 对于成块数据的报文段,TCP采 ...

  4. 「Ynoi2018」未来日记

    「Ynoi2018」未来日记 区间x->y,kth值... 不管了,先序列分块... 查询 第k值,假定知道每个数的权值,对值域分块. 对于整块,维护前\(i\)个块当中,值域在\(j\)块里以 ...

  5. AT3913 XOR Tree

    经过长时间的思考,我发现直接考虑对一条链进行修改是很难做出本题的,可能需要换一个方向. 可以发现本题中有操作的存在,是没有可以反过来做的做法的,因此正难则反这条路应该走不通. 那么唯一的办法就是简化这 ...

  6. 详解Java12新增语法switch表达式

    引言 在学习分支语句的时候,我们都学过 switch 语句,相比于 if-else 语句,他看起来更加整洁,逻辑更加清晰,Java中当然也给我们提了相关的 switch 方法.但是Java的强大之处在 ...

  7. shell脚本命令

    http://man.linuxde.net/shell-script   从键盘或文件中获取标准输入:read命令 文件的描述符和重定向 数组.关联数组和别名的使用 函数的定义.执行.传参和递归函数 ...

  8. 小程序"errcode":41002错误问题如何解决?

    我的问题是:小程序在本地测试的时候是没有问题的,但是当我扫开发者中的项目中的二维码手机浏览测试的时候发现是没有数据的,然后调试工具中出现: {"errcode":41002,&qu ...

  9. ELK-EFK-v7.12.0日志平台部署

    ELK和EFK是什么 ELK和EFK是四个开源产品的组合: Elasticsearch 一个基于Lucene搜索引擎的NoSQL数据库 Logstatsh 一个日志管道工具,接受数据输入,执行数据转换 ...

  10. MySQL架构原理之运行机制

    所谓运行机制即MySQL内部就如生产车间如何进行生产的.如下图: 1.建立连接,通过客户端/服务器通信协议与MySQL建立连接.MySQL客户端与服务端的通信方式是"半双工".对于 ...