pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》
论文 《 Convolutional Neural Networks for Sentence Classification》通过CNN实现了文本分类。
论文地址: 666666
模型图:

模型解释可以看论文,给出code and comment:https://github.com/graykode/nlp-tutorial
# -*- coding: utf-8 -*-
# @time : 2019/11/9 13:55 import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F dtype = torch.FloatTensor # Text-CNN Parameter
embedding_size = 2 # n-gram
sequence_length = 3
num_classes = 2 # 0 or 1
filter_sizes = [2, 2, 2] # n-gram window
num_filters = 3 # 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good. word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict) inputs = []
for sen in sentences:
inputs.append(np.asarray([word_dict[n] for n in sen.split()])) targets = []
for out in labels:
targets.append(out) # To using Torch Softmax Loss function input_batch = Variable(torch.LongTensor(inputs))
target_batch = Variable(torch.LongTensor(targets)) class TextCNN(nn.Module):
def __init__(self):
super(TextCNN, self).__init__() self.num_filters_total = num_filters * len(filter_sizes)
self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)
self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype)
self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype) def forward(self, X):
embedded_chars = self.W[X] # [batch_size, sequence_length, sequence_length]
embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size] pooled_outputs = []
for filter_size in filter_sizes:
# conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars)
h = F.relu(conv)
# mp : ((filter_height, filter_width))
mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
# pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
pooled = mp(h).permute(0, 3, 2, 1)
pooled_outputs.append(pooled) h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)] model = torch.mm(h_pool_flat, self.Weight) + self.Bias # [batch_size, num_classes]
return model model = TextCNN() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) # Training
for epoch in range(5000):
optimizer.zero_grad()
output = model(input_batch) # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)
loss = criterion(output, target_batch)
if (epoch + 1) % 1000 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss)) loss.backward()
optimizer.step() # Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = Variable(torch.LongTensor(tests)) # Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
print(test_text,"is Bad Mean...")
else:
print(test_text,"is Good Mean!!")
pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》的更多相关文章
- 卷积神经网络用语句子分类---Convolutional Neural Networks for Sentence Classification 学习笔记
读了一篇文章,用到卷积神经网络的方法来进行文本分类,故写下一点自己的学习笔记: 本文在事先进行单词向量的学习的基础上,利用卷积神经网络(CNN)进行句子分类,然后通过微调学习任务特定的向量,提高性能. ...
- 《Convolutional Neural Networks for Sentence Classification》 文本分类
文本分类任务中可以利用CNN来提取句子中类似 n-gram 的关键信息. TextCNN的详细过程原理图见下: keras 代码: def convs_block(data, convs=[3, 3, ...
- [NLP-CNN] Convolutional Neural Networks for Sentence Classification -2014-EMNLP
1. Overview 本文将CNN用于句子分类任务 (1) 使用静态vector + CNN即可取得很好的效果:=> 这表明预训练的vector是universal的特征提取器,可以被用于多种 ...
- CNN 文本分类
谈到文本分类,就不得不谈谈CNN(Convolutional Neural Networks).这个经典的结构在文本分类中取得了不俗的结果,而运用在这里的卷积可以分为1d .2d甚至是3d的. 下面 ...
- [转] Understanding Convolutional Neural Networks for NLP
http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/ 讲CNN以及其在NLP的应用,非常 ...
- Understanding Convolutional Neural Networks for NLP
When we hear about Convolutional Neural Network (CNNs), we typically think of Computer Vision. CNNs ...
- How to Use Convolutional Neural Networks for Time Series Classification
How to Use Convolutional Neural Networks for Time Series Classification 2019-10-08 12:09:35 This blo ...
- Deep learning_CNN_Review:A Survey of the Recent Architectures of Deep Convolutional Neural Networks——2019
CNN综述文章 的翻译 [2019 CVPR] A Survey of the Recent Architectures of Deep Convolutional Neural Networks 翻 ...
- [转]XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks
感谢: XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks XNOR-Net ImageNet Cl ...
随机推荐
- 用Python爬取了考研吧1000条帖子,原来他们都在讨论这些!
写在前面 考研在即,想多了解考研er的想法,就是去找学长学姐或者去网上搜索,贴吧就是一个好地方.而借助强大的工具可以快速从网络鱼龙混杂的信息中得到有价值的信息.虽然网上有很多爬取百度贴吧的教程和例子, ...
- es5和es6中查找数组中的元素
let array = [1,2,3,4,5] //es5 let find = array.filter(function (item){ return item %2 === 0//返回满足条件的 ...
- javalite 使用druid数据库连接池配置
在pom文件中引入jar包 <dependency> <groupId>com.alibaba</groupId> <artifactId>druid& ...
- redis订阅发布简单实现
适用场景 业务流程遇到大量异步操作,并且业务不是很复杂 业务的健壮型要求不高 对即时场景要求不高 原理介绍 redis官网文档:https://redis.io/topics/notification ...
- 通过例子进阶学习C++(七)CMake项目通过模板库实现约瑟夫环
本文是通过例子学习C++的第七篇,通过这个例子可以快速入门c++相关的语法. 1.问题描述 回顾一下约瑟夫环问题:n 个人围坐在一个圆桌周围,现在从第 s 个人开始报数,数到第 m 个人,让他出局:然 ...
- 测试必备之Java知识(二)—— Java高级的东西
Java高级 类加载过程 加载(创建class对象) -> 连接(验证-准备-解析) -> 类初始化 类加载器类别 根类加载器:加载java核心类 扩展类加载器:加载JRE目录中的jar包 ...
- MacOSX 安装 TensorFlow
TensorFlow是一个端到端开源机器学习平台.它拥有一个包含各种工具.库和社区资源的全面灵活生态系统,可以让研究人员推动机器学习领域的先进技术的. 准备 安装 Anaconda TensorFlo ...
- [bzoj1375] [Baltic2002] Bicriterial routing 双调路径
Description 如今的道路收费发展很快.道路的密度越来越大,因此选择最佳路径是很现实的问题.城市的道路是双向的,每条道路有固定的旅行时间以及需要支付的费用. 路径是连续经过的道路组成的.总时间 ...
- 大叔 Frameworks.Entity.Core 2 PageList
Frameworks.Entity.Core\Commons\PageList\ 1 分页通用类 ListPageList<T> 继承 PageListBase<T>, IP ...
- vuex 基本语法
VUEX 的核心概念 1 .State (常用):2.Getters :3.Mutations(常用):4.Actions :5.Modules: 1.State是唯一的数据源,单一的状态树 cons ...