薄书的pytorch项目实战lesson49-情感分类+蹭免费GPU
项目来源
B站视频pytorch项目实战-情感分类问题
github lesson49-情感分类实战
1 实验环境
在这里和大家推荐一个学习ML和DL的一个实验运行平台,就是google的Colaboratory,或者说一个白嫖GPU的实验平台。
大家直接在google搜colab就好,登入账号就可以用了。
什么是 Colaboratory?
借助 Colaboratory(简称 Colab),您可在浏览器中编写和执行 Python 代码,并且:
- 无需任何配置
- 免费使用 GPU
- 轻松共享
无论您是一名学生、数据科学家还是 AI 研究员,Colab 都能够帮助您更轻松地完成工作。您可以观看 Colab 简介了解详情,或查看入门指南!
对于 Colab 笔记本,您可以将可执行代码、富文本以及图像、HTML、LaTeX 等内容合入 1 个文档中。当您创建自己的 Colab 笔记本时,系统会将这些笔记本存储在您的 Google 云端硬盘帐号名下。您可以轻松地将 Colab 笔记本共享给同事或好友,允许他们评论甚至修改笔记本。要了解详情,请参阅 Colab 概览。要创建新的 Colab 笔记本,您可以使用上方的“文件”菜单,也可以使用以下链接:创建新的 Colab 笔记本。
Colab 笔记本是由 Colab 托管的 Jupyter 笔记本。如需详细了解 Jupyter 项目,请访问 jupyter.org。
使用过程中记得在 菜单栏>代码执行程序>更改运行时类型 中打开使用GPU加速
2 实验
2.1 环境配置和导入
!pip install torch
!pip install torchtext
!python -m spacy download en
# K80 gpu for 12 hours
import torch
from torch import nn, optim
from torchtext import data, datasets
print('GPU:', torch.cuda.is_available())
torch.manual_seed(123)
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (1.7.0+cu101)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch) (3.7.4.3)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch) (0.16.0)
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch) (0.8)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torch) (1.19.4)
Requirement already satisfied: torchtext in /usr/local/lib/python3.6/dist-packages (0.3.1)
Requirement already satisfied: torch in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.7.0+cu101)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from torchtext) (4.41.1)
Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from torchtext) (1.19.4)
Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from torchtext) (2.23.0)
Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch->torchtext) (3.7.4.3)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch->torchtext) (0.16.0)
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch->torchtext) (0.8)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (3.0.4)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (2020.12.5)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->torchtext) (1.24.3)
Requirement already satisfied: en_core_web_sm==2.2.5 from https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.2.5/en_core_web_sm-2.2.5.tar.gz#egg=en_core_web_sm==2.2.5 in /usr/local/lib/python3.6/dist-packages (2.2.5)
Requirement already satisfied: spacy>=2.2.2 in /usr/local/lib/python3.6/dist-packages (from en_core_web_sm==2.2.5) (2.2.4)
Requirement already satisfied: srsly<1.1.0,>=1.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.5)
Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (4.41.1)
Requirement already satisfied: thinc==7.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (7.4.0)
Requirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (51.0.0)
Requirement already satisfied: numpy>=1.15.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.19.4)
Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.5)
Requirement already satisfied: requests<3.0.0,>=2.13.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (2.23.0)
Requirement already satisfied: blis<0.5.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (0.4.1)
Requirement already satisfied: plac<1.2.0,>=0.9.6 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.1.3)
Requirement already satisfied: wasabi<1.1.0,>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (0.8.0)
Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (3.0.5)
Requirement already satisfied: catalogue<1.1.0,>=0.0.7 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (1.0.0)
Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from spacy>=2.2.2->en_core_web_sm==2.2.5) (2.0.5)
Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (1.24.3)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (2.10)
Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (2020.12.5)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests<3.0.0,>=2.13.0->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.0.4)
Requirement already satisfied: importlib-metadata>=0.20; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.3.0)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < "3.8"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.4.0)
Requirement already satisfied: typing-extensions>=3.6.4; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from importlib-metadata>=0.20; python_version < "3.8"->catalogue<1.1.0,>=0.0.7->spacy>=2.2.2->en_core_web_sm==2.2.5) (3.7.4.3)
Download and installation successful
You can now load the model via spacy.load('en_core_web_sm')
Linking successful
/usr/local/lib/python3.6/dist-packages/en_core_web_sm -->
/usr/local/lib/python3.6/dist-packages/spacy/data/en
You can now load the model via spacy.load('en')
GPU: True
<torch._C.Generator at 0x7f7acf579b10>
2.2 设置数据集
TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)
train_data, test_data = datasets.IMDB.splits(TEXT, LABEL)
# IMDB是torchtext提供的数据集
print('len of train data:', len(train_data))
print('len of test data:', len(test_data))
len of train data: 25000
len of test data: 25000
['First', 'I', 'was', 'caught', 'totally', 'off', 'guard', 'by', 'the', 'film', "'s", 'initial', 'lyricism', 'and', 'then', 'I', 'became', 'totally', 'enchanted', 'with', 'the', 'unfolding', 'story', 'and', 'engrossed', 'with', 'the', 'brilliant', 'directing', '.', 'The', 'characters', 'were', 'all', 'fully', 'developed', ',', 'not', 'bigger', '-', 'than', '-', 'life', 'but', 'just', 'like', 'the', 'people', 'we', 'live', 'among', 'anywhere', 'we', 'are', 'in', 'the', 'world', ',', 'in', 'Sweden', ',', 'in', 'Turkey', 'or', 'in', 'America', ',', 'all', 'completely', 'believable', 'human', 'beings', 'with', 'foibles', 'and', 'nobility', '.', 'Hollywood', 'could', 'learn', 'so', 'much', 'from', 'this', 'beautiful', 'film', '.', 'It', 'shows', 'that', 'there', 'is', 'no', 'need', 'to', 'go', 'into', 'every', 'little', 'detail', 'behind', 'every', 'action', 'to', 'bring', 'out', 'the', 'whole', 'theme', 'clear', 'and', 'bright', ',', 'and', 'that', 'shows', 'the', 'brilliance', 'of', 'the', 'director', '!', 'Hearfelt', 'thanks', 'to', 'Kay', 'Pollak', 'and', 'the', 'wonderful', 'cast', 'for', 'this', 'superb', 'treat', '!', '!']
pos
# word2vec, glove
TEXT.build_vocab(train_data, max_size=10000, vectors='glove.6B.100d')
LABEL.build_vocab(train_data)
batchsz = 30
device = torch.device('cuda')
train_iterator, test_iterator = data.BucketIterator.splits(
(train_data, test_data),
batch_size = batchsz,
device=device
)
.vector_cache/glove.6B.zip: 862MB [06:28, 2.22MB/s]
100%|█████████▉| 398704/400000 [00:16<00:00, 24622.99it/s]
2.3 搭建lstm网络
class RNN(nn.Module):
def __init__(self, vocab_size, embedding_dim, hidden_dim):
"""
"""
super(RNN, self).__init__()
# [0-10001] => [100] vocab_size=10002 embedding_dim=100,就是说10002个单词其中是10000个真的单词还有一个是不认识的单侧另一个是特殊符号,每个单词用长度100的向量表示
self.embedding = nn.Embedding(vocab_size, embedding_dim)
#[Embedding介绍1](https://zhuanlan.zhihu.com/p/53194407)
#[Embedding介绍2](https://www.cnblogs.com/USTC-ZCC/p/11068791.html)
# [100] => [256]
self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=2,
bidirectional=True, dropout=0.5)
# [双向循环神经网络bidirectional介绍](https://shenxiaohai.me/2018/10/19/pytorch-tutorial-intermediate-04/)
# [256*2] => [1]
self.fc = nn.Linear(hidden_dim*2, 1)
self.dropout = nn.Dropout(0.5)
def forward(self, x):
"""
x: [seq_len, b] vs [b, 3, 28, 28]
"""
# [seq, b, 1] => [seq, b, 100]
embedding = self.dropout(self.embedding(x))
# output: [seq, b, hid_dim*2]
# hidden/h: [num_layers*2, b, hid_dim]
# cell/c: [num_layers*2, b, hid_di]
output, (hidden, cell) = self.rnn(embedding)
# [num_layers*2, b, hid_dim] => 2 of [b, hid_dim] => [b, hid_dim*2]
hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)
# [b, hid_dim*2] => [b, 1]
hidden = self.dropout(hidden)
out = self.fc(hidden)
return out
2.4 embedding和网络优化
rnn = RNN(len(TEXT.vocab), 100, 256)
# 转换成embedding的形式
pretrained_embedding = TEXT.vocab.vectors
print('pretrained_embedding:', pretrained_embedding.shape)
rnn.embedding.weight.data.copy_(pretrained_embedding)
print('embedding layer inited.')
optimizer = optim.Adam(rnn.parameters(), lr=1e-3)
criteon = nn.BCEWithLogitsLoss().to(device)
rnn.to(device)
pretrained_embedding: torch.Size([10002, 100])
embedding layer inited.
RNN(
(embedding): Embedding(10002, 100)
(rnn): LSTM(100, 256, num_layers=2, dropout=0.5, bidirectional=True)
(fc): Linear(in_features=512, out_features=1, bias=True)
(dropout): Dropout(p=0.5, inplace=False)
)
2.5 训练与测试
import numpy as np
def binary_acc(preds, y):
"""
get accuracy
"""
preds = torch.round(torch.sigmoid(preds))
correct = torch.eq(preds, y).float()
acc = correct.sum() / len(correct)
return acc
def train(rnn, iterator, optimizer, criteon):
avg_acc = []
rnn.train()
for i, batch in enumerate(iterator): # 遍历所有训练数据
# [seq, b] => [b, 1] => [b]
pred = rnn(batch.text).squeeze(1)
#
loss = criteon(pred, batch.label)
acc = binary_acc(pred, batch.label).item()
avg_acc.append(acc)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if i%100 == 0:
print(i, acc)
avg_acc = np.array(avg_acc).mean()
print('avg acc:', avg_acc)
def eval(rnn, iterator, criteon):
avg_acc = []
rnn.eval()
with torch.no_grad():
for batch in iterator:
# [b, 1] => [b]
pred = rnn(batch.text).squeeze(1)
#
loss = criteon(pred, batch.label)
acc = binary_acc(pred, batch.label).item()
avg_acc.append(acc)
avg_acc = np.array(avg_acc).mean()
print('>>test:', avg_acc)
for epoch in range(10):
eval(rnn, test_iterator, criteon)
train(rnn, train_iterator, optimizer, criteon)
>>test: 0.8730615999915903
0 0.8666667342185974
100 0.8666667342185974
200 0.9666666984558105
300 0.9333333969116211
400 0.9333333969116211
500 0.9666666984558105
600 0.9000000357627869
700 0.9666666984558105
800 1.0
avg acc: 0.9348521599952552
>>test: 0.8765388191175117
0 0.8666667342185974
100 0.9000000357627869
200 1.0
300 0.9333333969116211
400 0.9333333969116211
500 0.9333333969116211
600 0.9666666984558105
700 0.8666667342185974
800 0.9000000357627869
avg acc: 0.9394085123527536
>>test: 0.8712630401984107
0 1.0
100 0.8666667342185974
200 0.9666666984558105
300 0.9666666984558105
400 1.0
500 0.9666666984558105
600 0.9666666984558105
700 0.9666666984558105
800 0.8333333730697632
avg acc: 0.94452441853585
>>test: 0.8790967720303889
0 0.9666666984558105
100 1.0
200 0.9666666984558105
300 0.9000000357627869
400 0.9333333969116211
500 0.9000000357627869
600 0.9000000357627869
700 0.9666666984558105
800 1.0
avg acc: 0.9481215391942351
>>test: 0.8758193941996824
0 0.9333333969116211
100 0.9666666984558105
200 1.0
300 0.9333333969116211
400 0.9666666984558105
500 0.9666666984558105
600 0.9666666984558105
700 0.9666666984558105
800 0.9333333969116211
avg acc: 0.9529176998338539
>>test: 0.8762590416329656
0 0.8666667342185974
100 0.9333333969116211
200 1.0
300 0.9666666984558105
400 0.8666667342185974
500 0.9333333969116211
600 0.9666666984558105
700 1.0
800 0.8666667342185974
avg acc: 0.9550360044558271
>>test: 0.8747402563941279
0 0.9666666984558105
100 0.9333333969116211
200 0.9666666984558105
300 0.9666666984558105
400 0.9333333969116211
500 0.9333333969116211
600 0.9333333969116211
700 0.9666666984558105
800 0.9333333969116211
avg acc: 0.958473253521702
>>test: 0.8732214720843793
0 0.9333333969116211
100 0.9666666984558105
200 0.9666666984558105
300 0.9666666984558105
400 0.9666666984558105
500 0.9666666984558105
600 1.0
700 0.9666666984558105
800 0.9333333969116211
avg acc: 0.9630296058792005
>>test: 0.8703038053546878
0 0.9666666984558105
100 0.9333333969116211
200 1.0
300 1.0
400 1.0
500 0.9333333969116211
600 0.9666666984558105
700 1.0
800 0.9666666984558105
avg acc: 0.965107941155811
>>test: 0.8725819842849704
0 1.0
100 1.0
200 0.9666666984558105
300 1.0
400 1.0
500 1.0
600 0.9333333969116211
700 1.0
800 0.9000000357627869
avg acc: 0.9668265646881908
薄书的pytorch项目实战lesson49-情感分类+蹭免费GPU的更多相关文章
- Java 架构师+高并发+性能优化+Spring boot大型分布式项目实战
视频课程内容包含: 高级 Java 架构师包含:Spring boot.Spring cloud.Dubbo.Redis.ActiveMQ.Nginx.Mycat.Spring.MongoDB.Zer ...
- 深度学习--LSTM网络、使用方法、实战情感分类问题
深度学习--LSTM网络.使用方法.实战情感分类问题 1.LSTM基础 长短期记忆网络(Long Short-Term Memory,简称LSTM),是RNN的一种,为了解决RNN存在长期依赖问题而设 ...
- PGL图学习之项目实践(UniMP算法实现论文节点分类、新冠疫苗项目实战,助力疫情)[系列九]
原项目链接:https://aistudio.baidu.com/aistudio/projectdetail/5100049?contributionType=1 1.图学习技术与应用 图是一个复杂 ...
- pytorch 文本情感分类和命名实体识别NER中LSTM输出的区别
文本情感分类: 文本情感分类采用LSTM的最后一层输出 比如双层的LSTM,使用正向的最后一层和反向的最后一层进行拼接 def forward(self,input): ''' :param inpu ...
- 【腾讯Bugly干货分享】React Native项目实战总结
本文来自于腾讯bugly开发者社区,非经作者同意,请勿转载,原文地址:http://dev.qq.com/topic/577e16a7640ad7b4682c64a7 “8小时内拼工作,8小时外拼成长 ...
- 【无私分享:ASP.NET CORE 项目实战(第十三章)】Asp.net Core 使用MyCat分布式数据库,实现读写分离
目录索引 [无私分享:ASP.NET CORE 项目实战]目录索引 简介 MyCat2.0版本很快就发布了,关于MyCat的动态和一些问题,大家可以加一下MyCat的官方QQ群:106088787.我 ...
- Net Core 项目实战之权限管理系统(0)
0 前言 Net Core 项目实战之权限管理系统(0) 无中生有 0 http://www.cnblogs.com/fonour/p/5848933.html 学习的最好方法就是动手去做,这里以 ...
- angularJs项目实战!01:模块划分和目录组织
近日来我有幸主导了一个典型的web app开发.该项目从产品层次来说是个典型的CRUD应用,故而我毫不犹豫地采用了grunt + boilerplate + angularjs + bootstrap ...
- 项目实战13—企业级虚拟化Virtualization-KVM技术
项目实战系列,总架构图 http://www.cnblogs.com/along21/p/8000812.html KVM的介绍.准备工作和qemu-kvm 命令详解 1.介绍 (1)介绍 KVM:就 ...
- 【.NET Core项目实战-统一认证平台】第十章 授权篇-客户端授权
[.NET Core项目实战-统一认证平台]开篇及目录索引 上篇文章介绍了如何使用Dapper持久化IdentityServer4(以下简称ids4)的信息,并实现了sqlserver和mysql两种 ...
随机推荐
- CENTOS 6.4 编译安装APACHE PHP MYSQL ZEND【转载未测试】
CENTOS 6.4 编译安装APACHE PHP MYSQL ZEND 由 cache 发布于 WWW 2013-07-21 [ 5560 ] 次浏览 [ 0 ] 条评论 标签: LAMP 搞网站跑 ...
- Pytorch-tensor的感知机,链式法则
1.单层感知机 单层感知机的主要步骤: 1.对数据进行一个权重的累加求和,求得∑ 2.将∑经过一个激活函数Sigmoid,得出值O 3.再将O,经过一个损失函数mse_loss,得出值loss 4.根 ...
- 实训篇-JavaScript-陶渊明去没去过桃花源
<!DOCTYPE html> <html> <head> <meta charset="utf-8"> <title> ...
- ASP.NET MVC5.0 筑基到炼气大圆满一篇就搞定
一.ASP.NET MVC 过滤器 ASP.NET MVC框架支持四种不同类型的过滤器: 授权过滤器 - 实现IAuthorizationFilter属性. 动作过滤器 - 实现IActionFilt ...
- 【Oracle】列拆行/对多行数据的单行数据进行分割并多行显示
[Oracle]列拆行/对多行数据的单行数据进行分割并多行显示 参考链接:Oracle 一行字符串拆分为多行_oracle一行拆分成多行-CSDN博客 背景:要对一个表的字段的内容进行分割,分隔符都是 ...
- 力扣138(java)- 复制带随机指针的链表(中等)
题目: 给你一个长度为 n 的链表,每个节点包含一个额外增加的随机指针 random ,该指针可以指向链表中的任何节点或空节点. 构造这个链表的 深拷贝. 深拷贝应该正好由 n 个 全新 节点组成,其 ...
- DataWorks搬站方案:Azkaban作业迁移至DataWorks
简介: DataWorks迁移助手提供任务搬站功能,支持将开源调度引擎Oozie.Azkaban.Airflow的任务快速迁移至DataWorks.本文主要介绍如何将开源Azkaban工作流调度引擎中 ...
- 【详谈 Delta Lake 】系列技术专题 之 Streaming(流式计算)
简介: 本文翻译自大数据技术公司 Databricks 针对数据湖 Delta Lake 的系列技术文章.众所周知,Databricks 主导着开源大数据社区 Apache Spark.Delta ...
- 业内首款云原生技术中台产品云原生 Stack 来了!
简介: 云原生 Stack 满足了各种典型场景下客户对于线下高集成平台的诉求,让企业数字话转型不受技术约束,专注业务本身,加速企业数字化迭代. 今天,企业数字化转型依然面临很大的挑战,虽然有很多新技术 ...
- Go 调用 Java 方案和性能优化分享
简介: 一个基于 Golang 编写的日志收集和清洗的应用需要支持一些基于 JVM 的算子. 作者 | 响风 来源 | 阿里技术公众号 一 背景 一个基于 Golang 编写的日志收集和清洗的应 ...