pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》
论文 《 Convolutional Neural Networks for Sentence Classification》通过CNN实现了文本分类。
论文地址: 666666
模型图:

模型解释可以看论文,给出code and comment:https://github.com/graykode/nlp-tutorial
# -*- coding: utf-8 -*-
# @time : 2019/11/9 13:55 import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F dtype = torch.FloatTensor # Text-CNN Parameter
embedding_size = 2 # n-gram
sequence_length = 3
num_classes = 2 # 0 or 1
filter_sizes = [2, 2, 2] # n-gram window
num_filters = 3 # 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good. word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict) inputs = []
for sen in sentences:
inputs.append(np.asarray([word_dict[n] for n in sen.split()])) targets = []
for out in labels:
targets.append(out) # To using Torch Softmax Loss function input_batch = Variable(torch.LongTensor(inputs))
target_batch = Variable(torch.LongTensor(targets)) class TextCNN(nn.Module):
def __init__(self):
super(TextCNN, self).__init__() self.num_filters_total = num_filters * len(filter_sizes)
self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)
self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype)
self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype) def forward(self, X):
embedded_chars = self.W[X] # [batch_size, sequence_length, sequence_length]
embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size] pooled_outputs = []
for filter_size in filter_sizes:
# conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars)
h = F.relu(conv)
# mp : ((filter_height, filter_width))
mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
# pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
pooled = mp(h).permute(0, 3, 2, 1)
pooled_outputs.append(pooled) h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)] model = torch.mm(h_pool_flat, self.Weight) + self.Bias # [batch_size, num_classes]
return model model = TextCNN() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) # Training
for epoch in range(5000):
optimizer.zero_grad()
output = model(input_batch) # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)
loss = criterion(output, target_batch)
if (epoch + 1) % 1000 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss)) loss.backward()
optimizer.step() # Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = Variable(torch.LongTensor(tests)) # Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
print(test_text,"is Bad Mean...")
else:
print(test_text,"is Good Mean!!")
pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》的更多相关文章
- 卷积神经网络用语句子分类---Convolutional Neural Networks for Sentence Classification 学习笔记
读了一篇文章,用到卷积神经网络的方法来进行文本分类,故写下一点自己的学习笔记: 本文在事先进行单词向量的学习的基础上,利用卷积神经网络(CNN)进行句子分类,然后通过微调学习任务特定的向量,提高性能. ...
- 《Convolutional Neural Networks for Sentence Classification》 文本分类
文本分类任务中可以利用CNN来提取句子中类似 n-gram 的关键信息. TextCNN的详细过程原理图见下: keras 代码: def convs_block(data, convs=[3, 3, ...
- [NLP-CNN] Convolutional Neural Networks for Sentence Classification -2014-EMNLP
1. Overview 本文将CNN用于句子分类任务 (1) 使用静态vector + CNN即可取得很好的效果:=> 这表明预训练的vector是universal的特征提取器,可以被用于多种 ...
- CNN 文本分类
谈到文本分类,就不得不谈谈CNN(Convolutional Neural Networks).这个经典的结构在文本分类中取得了不俗的结果,而运用在这里的卷积可以分为1d .2d甚至是3d的. 下面 ...
- [转] Understanding Convolutional Neural Networks for NLP
http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/ 讲CNN以及其在NLP的应用,非常 ...
- Understanding Convolutional Neural Networks for NLP
When we hear about Convolutional Neural Network (CNNs), we typically think of Computer Vision. CNNs ...
- How to Use Convolutional Neural Networks for Time Series Classification
How to Use Convolutional Neural Networks for Time Series Classification 2019-10-08 12:09:35 This blo ...
- Deep learning_CNN_Review:A Survey of the Recent Architectures of Deep Convolutional Neural Networks——2019
CNN综述文章 的翻译 [2019 CVPR] A Survey of the Recent Architectures of Deep Convolutional Neural Networks 翻 ...
- [转]XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks
感谢: XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks XNOR-Net ImageNet Cl ...
随机推荐
- Java 中级 学习笔记 2 JVM GC 垃圾回收与算法
前言 在上一节的学习中,已经了解到了关于JVM 内存相关的内容,比如JVM 内存的划分,以及JDK8当中对于元空间的定义,最后就是字符串常量池等基本概念以及容易混淆的内容,我们都已经做过一次总结了.不 ...
- Google搜索成最大入口,简单谈下个人博客的SEO
个人静态博客SEO该考虑哪些问题呢?本篇文章给你答案 咖啡君在开始写文章时首选了微信公众号作为发布平台,但公众号在PC端的体验并不好,连最基本的文章列表都没有,所以就搭建了运维咖啡吧的网站,可以通过点 ...
- 简述http协议及抓包分析
1:HTTP请求头和响应头的格式 1:HTTP请求格式:<request-line><headers><blank line>[<request-body&g ...
- Go Web 编程之 数据库
概述 数据库用来存储数据.只要不是玩具项目,每个项目都需要用到数据库.现在用的最多的还是 MySQL,PostgreSQL的使用也在快速增长中. 在 Web 开发中,数据库也是必须的.本文将介绍如何在 ...
- 关于github显示不出来图片的问题
今天打开github,突然发现图标图片等都显示不出来了. 控制台看了一下 百度查找了Failed to load resource: net::ERR_CERT_COMMON_NAME_INVALID ...
- 使用属性创建区域 (Creating Areas with Attributes) | 使用区域 | 高级路由特性 | 精通ASP-NET-MVC-5-弗瑞曼
- 如何使用Serilog.AspNetCore记录ASP.NET Core3.0的MVC属性
这是Serilog系列的第三篇文章. 第1部分-使用Serilog RequestLogging减少日志详细程度 第2部分-使用Serilog记录所选的终结点属性 第3部分-使用Serilog.Asp ...
- 利用Python进行博客图片压缩
自己写博客的时候常常要插入一些手机拍的照片,都是几M的大小,每张手动压缩太费事了,于是根据自己博客的排版特点用Python写了一个简单的图片压缩脚本,功能是将博客图片生成缩略图,横屏的图片压缩为宽度最 ...
- CSS-06-CSS颜色表示方法
<!DOCTYPE html> <html> <head> <meta charset="UTF-8"> <title> ...
- 工具之grep
转自:http://www.cnblogs.com/dong008259/archive/2011/12/07/2279897.html grep (global search regular exp ...