pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》
论文 《 Convolutional Neural Networks for Sentence Classification》通过CNN实现了文本分类。
论文地址: 666666
模型图:

模型解释可以看论文,给出code and comment:https://github.com/graykode/nlp-tutorial
# -*- coding: utf-8 -*-
# @time : 2019/11/9 13:55 import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F dtype = torch.FloatTensor # Text-CNN Parameter
embedding_size = 2 # n-gram
sequence_length = 3
num_classes = 2 # 0 or 1
filter_sizes = [2, 2, 2] # n-gram window
num_filters = 3 # 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good. word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict) inputs = []
for sen in sentences:
inputs.append(np.asarray([word_dict[n] for n in sen.split()])) targets = []
for out in labels:
targets.append(out) # To using Torch Softmax Loss function input_batch = Variable(torch.LongTensor(inputs))
target_batch = Variable(torch.LongTensor(targets)) class TextCNN(nn.Module):
def __init__(self):
super(TextCNN, self).__init__() self.num_filters_total = num_filters * len(filter_sizes)
self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)
self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype)
self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype) def forward(self, X):
embedded_chars = self.W[X] # [batch_size, sequence_length, sequence_length]
embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size] pooled_outputs = []
for filter_size in filter_sizes:
# conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars)
h = F.relu(conv)
# mp : ((filter_height, filter_width))
mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
# pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
pooled = mp(h).permute(0, 3, 2, 1)
pooled_outputs.append(pooled) h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)] model = torch.mm(h_pool_flat, self.Weight) + self.Bias # [batch_size, num_classes]
return model model = TextCNN() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) # Training
for epoch in range(5000):
optimizer.zero_grad()
output = model(input_batch) # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)
loss = criterion(output, target_batch)
if (epoch + 1) % 1000 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss)) loss.backward()
optimizer.step() # Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = Variable(torch.LongTensor(tests)) # Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
print(test_text,"is Bad Mean...")
else:
print(test_text,"is Good Mean!!")
pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》的更多相关文章
- 卷积神经网络用语句子分类---Convolutional Neural Networks for Sentence Classification 学习笔记
读了一篇文章,用到卷积神经网络的方法来进行文本分类,故写下一点自己的学习笔记: 本文在事先进行单词向量的学习的基础上,利用卷积神经网络(CNN)进行句子分类,然后通过微调学习任务特定的向量,提高性能. ...
- 《Convolutional Neural Networks for Sentence Classification》 文本分类
文本分类任务中可以利用CNN来提取句子中类似 n-gram 的关键信息. TextCNN的详细过程原理图见下: keras 代码: def convs_block(data, convs=[3, 3, ...
- [NLP-CNN] Convolutional Neural Networks for Sentence Classification -2014-EMNLP
1. Overview 本文将CNN用于句子分类任务 (1) 使用静态vector + CNN即可取得很好的效果:=> 这表明预训练的vector是universal的特征提取器,可以被用于多种 ...
- CNN 文本分类
谈到文本分类,就不得不谈谈CNN(Convolutional Neural Networks).这个经典的结构在文本分类中取得了不俗的结果,而运用在这里的卷积可以分为1d .2d甚至是3d的. 下面 ...
- [转] Understanding Convolutional Neural Networks for NLP
http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/ 讲CNN以及其在NLP的应用,非常 ...
- Understanding Convolutional Neural Networks for NLP
When we hear about Convolutional Neural Network (CNNs), we typically think of Computer Vision. CNNs ...
- How to Use Convolutional Neural Networks for Time Series Classification
How to Use Convolutional Neural Networks for Time Series Classification 2019-10-08 12:09:35 This blo ...
- Deep learning_CNN_Review:A Survey of the Recent Architectures of Deep Convolutional Neural Networks——2019
CNN综述文章 的翻译 [2019 CVPR] A Survey of the Recent Architectures of Deep Convolutional Neural Networks 翻 ...
- [转]XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks
感谢: XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks XNOR-Net ImageNet Cl ...
随机推荐
- Scala与Mongodb实践2-----图片、日期的存储读取
目的:在IDEA中实现图片.日期等相关的类型在mongodb存储读取 主要是Scala和mongodb里面的类型的转换.Scala里面的数据编码类型和mongodb里面的存储的数据类型各个不同.存在类 ...
- ASP.Net MVC 引用动态 js 脚本
希望可以动态生成 js 发送给客户端使用. layout页引用: <script type="text/javascript" src="@Url.Action( ...
- 贪心 park
来总结一道非常经典的好题 这一道题是通过贪心实现的 首先看到这一题的时间复杂度 n<=100000 需要一个比较玄学的做法 我们先假设把题干改成这个样子 一圈n个车位 停在每个车位都有一定的代价 ...
- CF 558 C
Amr loves Chemistry, and specially doing experiments. He is preparing for a new interesting experime ...
- C# html生成图片保存下载
最近有个需求,需要把内容生成图片,我找到一些资料可以将html页面生成图片并保存下载 下面是简单的实现 1.html页面 @{ Layout = null; } <!DOCTYPE html&g ...
- java.lang.UnsupportedOperationException: Manual close is not allowed over a Spring managed SqlSession
java.lang.UnsupportedOperationException: Manual close is not allowed over a Spring managed SqlSessio ...
- python3小脚本-监控服务器性能并插入mysql数据库
操作系统: centos版本 7.4 防火墙 关闭 selinux 关闭 python版本 3.6 mysql版本 5.7 #操作系统性能脚本 [root@localhost sql]# cat cp ...
- 百度搜索关键词联想API JSONP使用实例
许多搜索引擎都提供了关键词联想api,且大多数都是jsonp. Jsonp(JSON with Padding) 是 json 的一种"使用模式",可以让网页从别的域名(网站)那获 ...
- 移除sitemap中的entity
下面截图是sitemap所在的位置 如果遇到什么原因,当前使用的entity被弃用需要删除,必须要把当前site map 引用的entity也一并删除. 不然会导致site map不能正常加载
- 用CSS实现横向滚动条
在进行app制作时,需要使用横向滚动条是内容展示更完善 首先,这是html代码: 这是CSS代码: 要点: 设置显示内容的宽 white-space是防止内容自动折行 overflow-y设置为hid ...