项目来源

B站视频pytorch项目实战-情感分类问题

github lesson49-情感分类实战


1 实验环境

在这里和大家推荐一个学习ML和DL的一个实验运行平台,就是google的Colaboratory,或者说一个白嫖GPU的实验平台。

大家直接在google搜colab就好,登入账号就可以用了。

什么是 Colaboratory?

借助 Colaboratory(简称 Colab),您可在浏览器中编写和执行 Python 代码,并且:

  • 无需任何配置
  • 免费使用 GPU
  • 轻松共享

无论您是一名学生数据科学家还是 AI 研究员,Colab 都能够帮助您更轻松地完成工作。您可以观看 Colab 简介了解详情,或查看入门指南!

对于 Colab 笔记本,您可以将可执行代码富文本以及图像HTMLLaTeX 等内容合入 1 个文档中。当您创建自己的 Colab 笔记本时,系统会将这些笔记本存储在您的 Google 云端硬盘帐号名下。您可以轻松地将 Colab 笔记本共享给同事或好友,允许他们评论甚至修改笔记本。要了解详情,请参阅 Colab 概览。要创建新的 Colab 笔记本,您可以使用上方的“文件”菜单,也可以使用以下链接:创建新的 Colab 笔记本

Colab 笔记本是由 Colab 托管的 Jupyter 笔记本。如需详细了解 Jupyter 项目,请访问 jupyter.org

使用过程中记得在 菜单栏>代码执行程序>更改运行时类型 中打开使用GPU加速

2 实验

2.1 环境配置和导入

!pip install torch
!pip install torchtext
!python -m spacy download en # K80 gpu for 12 hours
import torch
from torch import nn, optim
from torchtext import data, datasets
print('GPU:', torch.cuda.is_available()) torch.manual_seed(123)
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.7.0+cu101)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch) (3.7.4.3)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch) (0.16.0)
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch) (0.8)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.19.4)
Requirement already satisfied: torchtext in /usr/local/lib/python3.6/dist-packages (0.3.1)
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.7.0+cu101)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from torchtext) (4.41.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.19.4)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from torchtext) (2.23.0)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->torchtext) (3.7.4.3)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->torchtext) (0.16.0)
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->torchtext) (0.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2020.12.5)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (1.24.3)
Requirement already satisfied: en_core_web_sm==2.2.5 from https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.5/en_core_web_sm-2.2.5.tar.gz#egg=en_core_web_sm==2.2.5 in /usr/local/lib/python3.6/dist-packages (2.2.5)
Requirement already satisfied: spacy>=2.2.2 in /usr/local/lib/python3.6/dist-packages (from en_core_web_sm==2.2.5) (2.2.4)
Requirement already satisfied: srsly<1.1.0,>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.5)
Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (4.41.1)
Requirement already satisfied: thinc==7.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (7.4.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (51.0.0)
Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.19.4)
Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.5)
Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (2.23.0)
Requirement already satisfied: blis<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (0.4.1)
Requirement already satisfied: plac<1.2.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.1.3)
Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (0.8.0)
Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (3.0.5)
Requirement already satisfied: catalogue<1.1.0,>=0.0.7 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.0)
Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (2.0.5)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (2.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (2020.12.5)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.0.4)
Requirement already satisfied: importlib-metadata>=0.20; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.3.0)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < "3.8"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.4.0)
Requirement already satisfied: typing-extensions>=3.6.4; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < "3.8"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.7.4.3)
Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
Linking successful
/usr/local/lib/python3.6/dist-packages/en_core_web_sm -->
/usr/local/lib/python3.6/dist-packages/spacy/data/en
You can now load the model via spacy.load('en')
GPU: True
<torch._C.Generator at 0x7f7acf579b10>

2.2 设置数据集

TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
# IMDB是torchtext提供的数据集
print('len of train data:', len(train_data))
print('len of test data:', len(test_data))
len of train data: 25000
len of test data: 25000
['First', 'I', 'was', 'caught', 'totally', 'off', 'guard', 'by', 'the', 'film', "'s", 'initial', 'lyricism', 'and', 'then', 'I', 'became', 'totally', 'enchanted', 'with', 'the', 'unfolding', 'story', 'and', 'engrossed', 'with', 'the', 'brilliant', 'directing', '.', 'The', 'characters', 'were', 'all', 'fully', 'developed', ',', 'not', 'bigger', '-', 'than', '-', 'life', 'but', 'just', 'like', 'the', 'people', 'we', 'live', 'among', 'anywhere', 'we', 'are', 'in', 'the', 'world', ',', 'in', 'Sweden', ',', 'in', 'Turkey', 'or', 'in', 'America', ',', 'all', 'completely', 'believable', 'human', 'beings', 'with', 'foibles', 'and', 'nobility', '.', 'Hollywood', 'could', 'learn', 'so', 'much', 'from', 'this', 'beautiful', 'film', '.', 'It', 'shows', 'that', 'there', 'is', 'no', 'need', 'to', 'go', 'into', 'every', 'little', 'detail', 'behind', 'every', 'action', 'to', 'bring', 'out', 'the', 'whole', 'theme', 'clear', 'and', 'bright', ',', 'and', 'that', 'shows', 'the', 'brilliance', 'of', 'the', 'director', '!', 'Hearfelt', 'thanks', 'to', 'Kay', 'Pollak', 'and', 'the', 'wonderful', 'cast', 'for', 'this', 'superb', 'treat', '!', '!']
pos
# word2vec, glove
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
LABEL.build_vocab(train_data) batchsz = 30
device = torch.device('cuda')
train_iterator, test_iterator = data.BucketIterator.splits(
(train_data, test_data),
batch_size = batchsz,
device=device
)
.vector_cache/glove.6B.zip: 862MB [06:28, 2.22MB/s]
100%|█████████▉| 398704/400000 [00:16<00:00, 24622.99it/s]

2.3 搭建lstm网络

class RNN(nn.Module):

    def __init__(self, vocab_size, embedding_dim, hidden_dim):
"""
"""
super(RNN, self).__init__() # [0-10001] => [100] vocab_size=10002 embedding_dim=100,就是说10002个单词其中是10000个真的单词还有一个是不认识的单侧另一个是特殊符号,每个单词用长度100的向量表示
self.embedding = nn.Embedding(vocab_size, embedding_dim)
#[Embedding介绍1](https://zhuanlan.zhihu.com/p/53194407)
#[Embedding介绍2](https://www.cnblogs.com/USTC-ZCC/p/11068791.html)
# [100] => [256]
self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
bidirectional=True, dropout=0.5)
# [双向循环神经网络bidirectional介绍](https://shenxiaohai.me/2018/10/19/pytorch-tutorial-intermediate-04/)
# [256*2] => [1]
self.fc = nn.Linear(hidden_dim*2, 1)
self.dropout = nn.Dropout(0.5) def forward(self, x):
"""
x: [seq_len, b] vs [b, 3, 28, 28]
"""
# [seq, b, 1] => [seq, b, 100]
embedding = self.dropout(self.embedding(x)) # output: [seq, b, hid_dim*2]
# hidden/h: [num_layers*2, b, hid_dim]
# cell/c: [num_layers*2, b, hid_di]
output, (hidden, cell) = self.rnn(embedding) # [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]
hidden = torch.cat([hidden[-2], hidden[-1]], dim=1) # [b, hid_dim*2] => [b, 1]
hidden = self.dropout(hidden)
out = self.fc(hidden) return out

2.4 embedding和网络优化

rnn = RNN(len(TEXT.vocab), 100, 256)
# 转换成embedding的形式
pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.') optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)
pretrained_embedding: torch.Size([10002, 100])
embedding layer inited.
RNN(
(embedding): Embedding(10002, 100)
(rnn): LSTM(100, 256, num_layers=2, dropout=0.5, bidirectional=True)
(fc): Linear(in_features=512, out_features=1, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
)

2.5 训练与测试

import numpy as np

def binary_acc(preds, y):
"""
get accuracy
"""
preds = torch.round(torch.sigmoid(preds))
correct = torch.eq(preds, y).float()
acc = correct.sum() / len(correct)
return acc def train(rnn, iterator, optimizer, criteon): avg_acc = []
rnn.train() for i, batch in enumerate(iterator): # 遍历所有训练数据 # [seq, b] => [b, 1] => [b]
pred = rnn(batch.text).squeeze(1)
#
loss = criteon(pred, batch.label)
acc = binary_acc(pred, batch.label).item()
avg_acc.append(acc) optimizer.zero_grad()
loss.backward()
optimizer.step() if i%100 == 0:
print(i, acc) avg_acc = np.array(avg_acc).mean()
print('avg acc:', avg_acc) def eval(rnn, iterator, criteon): avg_acc = [] rnn.eval() with torch.no_grad():
for batch in iterator: # [b, 1] => [b]
pred = rnn(batch.text).squeeze(1) #
loss = criteon(pred, batch.label) acc = binary_acc(pred, batch.label).item()
avg_acc.append(acc) avg_acc = np.array(avg_acc).mean() print('>>test:', avg_acc)
for epoch in range(10):

    eval(rnn, test_iterator, criteon)
train(rnn, train_iterator, optimizer, criteon)
>>test: 0.8730615999915903
0 0.8666667342185974
100 0.8666667342185974
200 0.9666666984558105
300 0.9333333969116211
400 0.9333333969116211
500 0.9666666984558105
600 0.9000000357627869
700 0.9666666984558105
800 1.0
avg acc: 0.9348521599952552
>>test: 0.8765388191175117
0 0.8666667342185974
100 0.9000000357627869
200 1.0
300 0.9333333969116211
400 0.9333333969116211
500 0.9333333969116211
600 0.9666666984558105
700 0.8666667342185974
800 0.9000000357627869
avg acc: 0.9394085123527536
>>test: 0.8712630401984107
0 1.0
100 0.8666667342185974
200 0.9666666984558105
300 0.9666666984558105
400 1.0
500 0.9666666984558105
600 0.9666666984558105
700 0.9666666984558105
800 0.8333333730697632
avg acc: 0.94452441853585
>>test: 0.8790967720303889
0 0.9666666984558105
100 1.0
200 0.9666666984558105
300 0.9000000357627869
400 0.9333333969116211
500 0.9000000357627869
600 0.9000000357627869
700 0.9666666984558105
800 1.0
avg acc: 0.9481215391942351
>>test: 0.8758193941996824
0 0.9333333969116211
100 0.9666666984558105
200 1.0
300 0.9333333969116211
400 0.9666666984558105
500 0.9666666984558105
600 0.9666666984558105
700 0.9666666984558105
800 0.9333333969116211
avg acc: 0.9529176998338539
>>test: 0.8762590416329656
0 0.8666667342185974
100 0.9333333969116211
200 1.0
300 0.9666666984558105
400 0.8666667342185974
500 0.9333333969116211
600 0.9666666984558105
700 1.0
800 0.8666667342185974
avg acc: 0.9550360044558271
>>test: 0.8747402563941279
0 0.9666666984558105
100 0.9333333969116211
200 0.9666666984558105
300 0.9666666984558105
400 0.9333333969116211
500 0.9333333969116211
600 0.9333333969116211
700 0.9666666984558105
800 0.9333333969116211
avg acc: 0.958473253521702
>>test: 0.8732214720843793
0 0.9333333969116211
100 0.9666666984558105
200 0.9666666984558105
300 0.9666666984558105
400 0.9666666984558105
500 0.9666666984558105
600 1.0
700 0.9666666984558105
800 0.9333333969116211
avg acc: 0.9630296058792005
>>test: 0.8703038053546878
0 0.9666666984558105
100 0.9333333969116211
200 1.0
300 1.0
400 1.0
500 0.9333333969116211
600 0.9666666984558105
700 1.0
800 0.9666666984558105
avg acc: 0.965107941155811
>>test: 0.8725819842849704
0 1.0
100 1.0
200 0.9666666984558105
300 1.0
400 1.0
500 1.0
600 0.9333333969116211
700 1.0
800 0.9000000357627869
avg acc: 0.9668265646881908

薄书的pytorch项目实战lesson49-情感分类+蹭免费GPU的更多相关文章

  1. Java 架构师+高并发+性能优化+Spring boot大型分布式项目实战

    视频课程内容包含: 高级 Java 架构师包含:Spring boot.Spring cloud.Dubbo.Redis.ActiveMQ.Nginx.Mycat.Spring.MongoDB.Zer ...

  2. 深度学习--LSTM网络、使用方法、实战情感分类问题

    深度学习--LSTM网络.使用方法.实战情感分类问题 1.LSTM基础 长短期记忆网络(Long Short-Term Memory,简称LSTM),是RNN的一种,为了解决RNN存在长期依赖问题而设 ...

  3. PGL图学习之项目实践(UniMP算法实现论文节点分类、新冠疫苗项目实战,助力疫情)[系列九]

    原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5100049?contributionType=1 1.图学习技术与应用 图是一个复杂 ...

  4. pytorch 文本情感分类和命名实体识别NER中LSTM输出的区别

    文本情感分类: 文本情感分类采用LSTM的最后一层输出 比如双层的LSTM,使用正向的最后一层和反向的最后一层进行拼接 def forward(self,input): ''' :param inpu ...

  5. 【腾讯Bugly干货分享】React Native项目实战总结

    本文来自于腾讯bugly开发者社区,非经作者同意,请勿转载,原文地址:http://dev.qq.com/topic/577e16a7640ad7b4682c64a7 “8小时内拼工作,8小时外拼成长 ...

  6. 【无私分享:ASP.NET CORE 项目实战(第十三章)】Asp.net Core 使用MyCat分布式数据库,实现读写分离

    目录索引 [无私分享:ASP.NET CORE 项目实战]目录索引 简介 MyCat2.0版本很快就发布了,关于MyCat的动态和一些问题,大家可以加一下MyCat的官方QQ群:106088787.我 ...

  7. Net Core 项目实战之权限管理系统(0)

    0 前言 Net Core 项目实战之权限管理系统(0) 无中生有   0 http://www.cnblogs.com/fonour/p/5848933.html 学习的最好方法就是动手去做,这里以 ...

  8. angularJs项目实战!01:模块划分和目录组织

    近日来我有幸主导了一个典型的web app开发.该项目从产品层次来说是个典型的CRUD应用,故而我毫不犹豫地采用了grunt + boilerplate + angularjs + bootstrap ...

  9. 项目实战13—企业级虚拟化Virtualization-KVM技术

    项目实战系列,总架构图 http://www.cnblogs.com/along21/p/8000812.html KVM的介绍.准备工作和qemu-kvm 命令详解 1.介绍 (1)介绍 KVM:就 ...

  10. 【.NET Core项目实战-统一认证平台】第十章 授权篇-客户端授权

    [.NET Core项目实战-统一认证平台]开篇及目录索引 上篇文章介绍了如何使用Dapper持久化IdentityServer4(以下简称ids4)的信息,并实现了sqlserver和mysql两种 ...

随机推荐

  1. 重新点亮shell————awk 控制语句[十三]

    前言 简单介绍一下控制语句. 正文 例子1: 例子2: 例子3 for循环: 例子4, sum会复用: 同样,其他的while 和 do while 也是可以在awk中使用的. 结 下一节awk数组.

  2. redis 简单整理——java 客户端jedis[十六]

    前言 简单介绍一下java客户端jedis. 正文 Java有很多优秀的Redis客户端(详见:http://redis.io/clients#java),这 里介绍使用较为广泛的客户端Jedis,本 ...

  3. Worker 进行多线程任务开发

    概念介绍 在 OpenHarmony 中,UI 线程负责处理 UI 事件和用户交互,而 Worker 线程用于处理耗时操作,以提高应用程序的响应速度和用户体验. Worker 线程是与主线程并行的独立 ...

  4. Linux下的常见基本指令

    pwd //显示当前用户所在的路径 ls //显示当前路径下的文件名或者目录名称 ls-l //显示当前路径下的文件或者目录的更详细的属性信息 cd 一个目录路径 //进入一个目录,进去后,可以用pw ...

  5. 力扣1132(MySQL)-报告的记录Ⅱ(中等)

    题目: 编写一段 SQL 来查找:在被报告为垃圾广告的帖子中,被移除的帖子的每日平均占比,四舍五入到小数点后 2 位. Actions 表: Removals 表: Result 表: 2019-07 ...

  6. 从零开始入门 K8s | 理解 CNI 和 CNI 插件

    作者 | 溪恒 阿里巴巴高级技术专家 本文整理自<CNCF x Alibaba 云原生技术公开课>第 26 讲,点击直达课程页面. 关注"阿里巴巴云原生"公众号,回复关 ...

  7. DNS高可用设计--软件高可用

    DNS是网络的基础服务,网络上的各种应用对DNS的依赖性很高.DNS的稳定,直接决定了上层应用服务的稳定.那如何保障DNS服务的高可用呢?我们先来看下高可用的概念: 高可用 高可用(High avai ...

  8. 如何避免 Go 命令行执行产生“孤儿”进程?

    简介: 在 Go 程序当中,如果我们要执行命令时,通常会使用 exec.Command ,也比较好用,通常状况下,可以达到我们的目的,如果我们逻辑当中,需要终止这个进程,则可以快速使用 cmd.Pro ...

  9. [Linux] 日志管理: rsyslog 日志格式 / 配置文件详解

    1. 日志文件格式包含以下四列: 事件时间 | 发生事件的服务器的主机名 | 产生事件的服务名或程序名 | 事件的具体信息 2. /etc/rsyslog.conf 配置文件 # 服务名称 [连接符号 ...

  10. 5.k8s Service四层负载:服务端口暴露

    题目一:暴露服务service 设置配置环境: [candidate@node-1] $ kubectl config use-context k8s Task 请重新配置现有的 deployment ...