The paper "Convolutional Neural Networks for Sentence Classification" performs text classification with a CNN.

Paper link: 666666

Model diagram: [figure omitted]
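The core operation of the model, a filter sliding over consecutive word embeddings followed by max-over-time pooling, can be sketched on toy numbers. This is only an illustration: the embedding values, the filter `w`, and the bias `b` below are invented for the example and are not taken from the paper or from the tutorial code.

```python
import torch

# Toy sentence: 3 words, each a 2-dimensional embedding (made-up values).
sentence = torch.tensor([[1.0, 0.0],
                         [0.0, 2.0],
                         [-1.0, 1.0]])

# One hypothetical filter covering 2 consecutive words (a bigram window).
w = torch.tensor([[1.0, 1.0],
                  [1.0, -1.0]])
b = 0.5
h = w.shape[0]  # window height, in words

# Slide the filter over every window of h words, then apply ReLU.
features = torch.stack([
    torch.relu((w * sentence[i:i + h]).sum() + b)
    for i in range(sentence.shape[0] - h + 1)
])

# Max-over-time pooling keeps the strongest response of this filter.
pooled = features.max()

print(features)  # one feature per window position
print(pooled)    # a single scalar per filter
```

With 3 words and a window of 2 there are 3 - 2 + 1 = 2 window positions, so each filter contributes one pooled value regardless of sentence length.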

  

For an explanation of the model, see the paper. The code below, with comments, comes from https://github.com/graykode/nlp-tutorial

# -*- coding: utf-8 -*-
# @time : 2019/11/9 13:55

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Text-CNN parameters
embedding_size = 2  # dimension of each word embedding
sequence_length = 3  # every sentence has 3 words
num_classes = 2  # 0 or 1
filter_sizes = [2, 2, 2]  # n-gram window sizes
num_filters = 3  # output channels per filter size

# 3-word sentences (sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good

word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict)

# Map every sentence to a list of word indices
inputs = [[word_dict[n] for n in sen.split()] for sen in sentences]
targets = list(labels)

# CrossEntropyLoss expects class indices (LongTensor), not one-hot vectors
input_batch = torch.LongTensor(inputs)
target_batch = torch.LongTensor(targets)


class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.num_filters_total = num_filters * len(filter_sizes)
        self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1))
        # Build the convolutions here so their weights are registered and trained;
        # creating them inside forward() would re-initialize them on every call.
        # conv: [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
        self.convs = nn.ModuleList(
            [nn.Conv2d(1, num_filters, (size, embedding_size), bias=True) for size in filter_sizes]
        )
        self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1))
        self.Bias = nn.Parameter(0.1 * torch.ones([num_classes]))

    def forward(self, X):
        embedded_chars = self.W[X]  # [batch_size, sequence_length, embedding_size]
        embedded_chars = embedded_chars.unsqueeze(1)  # add channel(=1): [batch, channel(=1), sequence_length, embedding_size]

        pooled_outputs = []
        for filter_size, conv in zip(filter_sizes, self.convs):
            # h: [batch_size, num_filters, sequence_length - filter_size + 1, 1]
            h = F.relu(conv(embedded_chars))
            # mp: kernel is (filter_height, filter_width), covering all window positions
            mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
            # pooled: [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
            pooled = mp(h).permute(0, 3, 2, 1)
            pooled_outputs.append(pooled)

        # Concatenate along the last (channel) dimension
        h_pool = torch.cat(pooled_outputs, 3)  # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
        h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total])  # [batch_size(=6), output_height * output_width * (output_channel * 3)]
        logits = torch.mm(h_pool_flat, self.Weight) + self.Bias  # [batch_size, num_classes]
        return logits


model = TextCNN()

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5000):
    optimizer.zero_grad()
    output = model(input_batch)
    # output: [batch_size, num_classes], target_batch: [batch_size] (LongTensor, not one-hot)
    loss = criterion(output, target_batch)
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss.item()))
    loss.backward()
    optimizer.step()

# Test
test_text = 'sorry hate you'
tests = [[word_dict[n] for n in test_text.split()]]
test_batch = torch.LongTensor(tests)

# Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
    print(test_text, "is Bad Mean...")
else:
    print(test_text, "is Good Mean!!")
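The tensor shapes noted in the comments can be checked in isolation. The following is a minimal sketch that pushes a random batch through a single convolution and pooling step using the same hyperparameter values as the listing (6 sentences, 3 words, embedding size 2, 3 filters of height 2). The random input stands in for the embedded sentences and is an assumption made only for illustration.

```python
import torch
import torch.nn as nn

batch_size, sequence_length, embedding_size = 6, 3, 2
num_filters, filter_size = 3, 2

# Random stand-in for the embedded batch: [batch, channel(=1), seq_len, emb_size]
x = torch.randn(batch_size, 1, sequence_length, embedding_size)

conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size))
pool = nn.MaxPool2d((sequence_length - filter_size + 1, 1))

conv_out = conv(x)                   # one row per window position
pooled = pool(torch.relu(conv_out))  # max over all window positions
flat = pooled.permute(0, 3, 2, 1).reshape(batch_size, -1)

print(tuple(conv_out.shape))  # (6, 3, 2, 1)
print(tuple(pooled.shape))    # (6, 3, 1, 1)
print(tuple(flat.shape))      # (6, 3)
```

Because the filter width equals the embedding size, the convolution output has width 1, and the pooling kernel height of `sequence_length - filter_size + 1` collapses the remaining positions to a single value per filter.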
