论文  《 Convolutional Neural Networks for Sentence Classification》通过CNN实现了文本分类。

论文地址: 666666

模型图:

  

模型解释可以看论文,给出code and comment:https://github.com/graykode/nlp-tutorial

 # -*- coding: utf-8 -*-
# @time : 2019/11/9 13:55 import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F dtype = torch.FloatTensor # Text-CNN Parameter
embedding_size = 2 # n-gram
sequence_length = 3
num_classes = 2 # 0 or 1
filter_sizes = [2, 2, 2] # n-gram window
num_filters = 3 # 3 words sentences (=sequence_length is 3)
sentences = ["i love you", "he loves me", "she likes baseball", "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0] # 1 is good, 0 is not good. word_list = " ".join(sentences).split()
word_list = list(set(word_list))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict) inputs = []
for sen in sentences:
inputs.append(np.asarray([word_dict[n] for n in sen.split()])) targets = []
for out in labels:
targets.append(out) # To using Torch Softmax Loss function input_batch = Variable(torch.LongTensor(inputs))
target_batch = Variable(torch.LongTensor(targets)) class TextCNN(nn.Module):
def __init__(self):
super(TextCNN, self).__init__() self.num_filters_total = num_filters * len(filter_sizes)
self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1)).type(dtype)
self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1)).type(dtype)
self.Bias = nn.Parameter(0.1 * torch.ones([num_classes])).type(dtype) def forward(self, X):
embedded_chars = self.W[X] # [batch_size, sequence_length, sequence_length]
embedded_chars = embedded_chars.unsqueeze(1) # add channel(=1) [batch, channel(=1), sequence_length, embedding_size] pooled_outputs = []
for filter_size in filter_sizes:
# conv : [input_channel(=1), output_channel(=3), (filter_height, filter_width), bias_option]
conv = nn.Conv2d(1, num_filters, (filter_size, embedding_size), bias=True)(embedded_chars)
h = F.relu(conv)
# mp : ((filter_height, filter_width))
mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
# pooled : [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3)]
pooled = mp(h).permute(0, 3, 2, 1)
pooled_outputs.append(pooled) h_pool = torch.cat(pooled_outputs, len(filter_sizes)) # [batch_size(=6), output_height(=1), output_width(=1), output_channel(=3) * 3]
h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total]) # [batch_size(=6), output_height * output_width * (output_channel * 3)] model = torch.mm(h_pool_flat, self.Weight) + self.Bias # [batch_size, num_classes]
return model model = TextCNN() criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001) # Training
for epoch in range(5000):
optimizer.zero_grad()
output = model(input_batch) # output : [batch_size, num_classes], target_batch : [batch_size] (LongTensor, not one-hot)
loss = criterion(output, target_batch)
if (epoch + 1) % 1000 == 0:
print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss)) loss.backward()
optimizer.step() # Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = Variable(torch.LongTensor(tests)) # Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
print(test_text,"is Bad Mean...")
else:
print(test_text,"is Good Mean!!")

pytorch -- CNN 文本分类 -- 《 Convolutional Neural Networks for Sentence Classification》的更多相关文章

  1. 卷积神经网络用语句子分类---Convolutional Neural Networks for Sentence Classification 学习笔记

    读了一篇文章,用到卷积神经网络的方法来进行文本分类,故写下一点自己的学习笔记: 本文在事先进行单词向量的学习的基础上,利用卷积神经网络(CNN)进行句子分类,然后通过微调学习任务特定的向量,提高性能. ...

  2. 《Convolutional Neural Networks for Sentence Classification》 文本分类

    文本分类任务中可以利用CNN来提取句子中类似 n-gram 的关键信息. TextCNN的详细过程原理图见下: keras 代码: def convs_block(data, convs=[3, 3, ...

  3. [NLP-CNN] Convolutional Neural Networks for Sentence Classification -2014-EMNLP

    1. Overview 本文将CNN用于句子分类任务 (1) 使用静态vector + CNN即可取得很好的效果:=> 这表明预训练的vector是universal的特征提取器,可以被用于多种 ...

  4. CNN 文本分类

    谈到文本分类,就不得不谈谈CNN(Convolutional Neural Networks).这个经典的结构在文本分类中取得了不俗的结果,而运用在这里的卷积可以分为1d .2d甚至是3d的.  下面 ...

  5. [转] Understanding Convolutional Neural Networks for NLP

    http://www.wildml.com/2015/11/understanding-convolutional-neural-networks-for-nlp/ 讲CNN以及其在NLP的应用,非常 ...

  6. Understanding Convolutional Neural Networks for NLP

    When we hear about Convolutional Neural Network (CNNs), we typically think of Computer Vision. CNNs ...

  7. How to Use Convolutional Neural Networks for Time Series Classification

    How to Use Convolutional Neural Networks for Time Series Classification 2019-10-08 12:09:35 This blo ...

  8. Deep learning_CNN_Review:A Survey of the Recent Architectures of Deep Convolutional Neural Networks——2019

    CNN综述文章 的翻译 [2019 CVPR] A Survey of the Recent Architectures of Deep Convolutional Neural Networks 翻 ...

  9. [转]XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks

    感谢: XNOR-Net ImageNet Classification Using Binary Convolutional Neural Networks XNOR-Net ImageNet Cl ...

随机推荐

  1. c++ 快读快输模板

    快读 inline int read() { ; ; char ch=getchar(); ; ch=getchar();} )+(X<<)+ch-'; ch=getchar();} if ...

  2. 求二叉树的深度,从根节点到叶子节点的最大值,以及最大路径(python代码实现)

    首先定义一个节点类,包含三个成员变量,分别是节点值,左指针,右指针,如下代码所示: class Node(object): def __init__(self, value): self.value ...

  3. BigDecimal的加减乘除,比较,小数保留

    关于BigDecimal的一些常用基本操作记录 1        BigDecimal b1 = new BigDecimal("1.124"); 2        BigDeci ...

  4. 关于爬虫的日常复习(2)—— urllib库

  5. 第一个javaWeb项目-注册界面

    基本功能: 实现信息录入到数据库和信息规范检查 题目要求: 项目目录: 网页界面: 程序源码: package Dao; import java.sql.Connection; import java ...

  6. git工作中总结

    # .克隆到本地 git clone url git clone -b 分支 url # 注意:克隆完成后,要删除.git隐藏文件夹 # .修改代码 # .生成master git init git ...

  7. 【javaWeb】HttpServletRequest常用获取URL的方法

    1.request.getRequestURL() 返回的是完整的url,包括Http协议,端口号,servlet名字和映射路径,但它不包含请求参数. 2.request.getRequestURI( ...

  8. 分层有限状态机的C++实现

    为了方便我的游戏开发,写了这么一个通用的分层有限状态机.希望在其稳定以后,可以作为一个组件加入到我的游戏引擎当中. 目前使用了std::function来调用回调函数,在未来可能会用委托机制代替. 第 ...

  9. Math.Atan2 方法

    返回正切值为两个指定数字的商的角度. public static double Atan2 ( double y, double x ) 参数 y 点的 y 坐标. x 点的 x 坐标. 返回值 角  ...

  10. potel99se 文件损坏修复

    一直使用protel99se来做电路图,非常方便快捷.最近一次打开常用的一个ddb文件,提示失败,无法打开了.protel99使用的数据库文件实际上是access97 的mdb数据库,于是修改成mdb ...