PyTorch -- CNN Text Classification -- "Convolutional Neural Networks for Sentence Classification"
The paper "Convolutional Neural Networks for Sentence Classification" implements text classification with a CNN.
Paper link: 666666
Model diagram: (the architecture figure from the paper, not reproduced here)
The model itself is explained in the paper; the code and comments below follow https://github.com/graykode/nlp-tutorial.
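Since the figure is not shown here, a minimal standalone sketch of the tensor shapes through one convolution-and-pooling branch may help (it uses the same toy sizes as the code below; the variable names are only for illustration):

import torch
import torch.nn.functional as F

# toy sizes matching the code below: 6 sentences of 3 words,
# embedding_size=2, num_filters=3, filter_size=2
x = torch.randn(6, 1, 3, 2)               # [batch, channel=1, sequence_length, embedding_size]
w = torch.randn(3, 1, 2, 2)               # [num_filters, in_channel=1, filter_size, embedding_size]
h = F.relu(F.conv2d(x, w))                # [6, 3, 2, 1]: one column of n-gram features per filter
pooled = F.max_pool2d(h, (h.size(2), 1))  # max-over-time pooling -> [6, 3, 1, 1]
print(h.shape, pooled.shape)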
# -*- coding: utf-8 -*-
# @time : 2019/11/9 13:55
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

# Text-CNN parameters
embedding_size = 2          # word embedding dimension
sequence_length = 3         # every sentence has 3 words
num_classes = 2             # 0 or 1
filter_sizes = [2, 2, 2]    # n-gram window sizes
num_filters = 3             # filters per window size

sentences = ["i love you", "he loves me", "she likes baseball",
             "i hate you", "sorry for that", "this is awful"]
labels = [1, 1, 1, 0, 0, 0]  # 1 is good, 0 is not good

# build the vocabulary
word_list = list(set(" ".join(sentences).split()))
word_dict = {w: i for i, w in enumerate(word_list)}
vocab_size = len(word_dict)

# word indices as inputs; class indices (not one-hot) as targets,
# because nn.CrossEntropyLoss expects LongTensor class labels
inputs = [np.asarray([word_dict[n] for n in sen.split()]) for sen in sentences]
targets = list(labels)

input_batch = torch.LongTensor(np.array(inputs))  # [batch_size, sequence_length]
target_batch = torch.LongTensor(targets)          # [batch_size]


class TextCNN(nn.Module):
    def __init__(self):
        super(TextCNN, self).__init__()
        self.num_filters_total = num_filters * len(filter_sizes)
        # embedding table and output projection, initialized uniformly in [-1, 1]
        self.W = nn.Parameter(torch.empty(vocab_size, embedding_size).uniform_(-1, 1))
        self.Weight = nn.Parameter(torch.empty(self.num_filters_total, num_classes).uniform_(-1, 1))
        self.Bias = nn.Parameter(0.1 * torch.ones([num_classes]))
        # one Conv2d per filter size; defining them here (instead of inside
        # forward) registers their weights so the optimizer actually trains them
        self.convs = nn.ModuleList([
            nn.Conv2d(1, num_filters, (size, embedding_size), bias=True)
            for size in filter_sizes
        ])

    def forward(self, X):
        embedded_chars = self.W[X]                    # [batch_size, sequence_length, embedding_size]
        embedded_chars = embedded_chars.unsqueeze(1)  # add channel(=1): [batch_size, 1, sequence_length, embedding_size]

        pooled_outputs = []
        for filter_size, conv in zip(filter_sizes, self.convs):
            # conv output: [batch_size, num_filters, sequence_length - filter_size + 1, 1]
            h = F.relu(conv(embedded_chars))
            # max-over-time pooling collapses the time axis to 1
            mp = nn.MaxPool2d((sequence_length - filter_size + 1, 1))
            # pooled: [batch_size, 1, 1, num_filters(=3)]
            pooled = mp(h).permute(0, 3, 2, 1)
            pooled_outputs.append(pooled)

        # concatenate the pooled features of all filter sizes:
        # [batch_size, 1, 1, num_filters * len(filter_sizes)]
        h_pool = torch.cat(pooled_outputs, len(filter_sizes))
        h_pool_flat = torch.reshape(h_pool, [-1, self.num_filters_total])  # [batch_size, num_filters_total]
        model = torch.mm(h_pool_flat, self.Weight) + self.Bias             # [batch_size, num_classes]
        return model


model = TextCNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training
for epoch in range(5000):
    optimizer.zero_grad()
    output = model(input_batch)  # output: [batch_size, num_classes], target_batch: [batch_size]
    loss = criterion(output, target_batch)
    if (epoch + 1) % 1000 == 0:
        print('Epoch:', '%04d' % (epoch + 1), 'cost =', '{:.6f}'.format(loss))
    loss.backward()
    optimizer.step()

# Test
test_text = 'sorry hate you'
tests = [np.asarray([word_dict[n] for n in test_text.split()])]
test_batch = torch.LongTensor(np.array(tests))

# Predict
predict = model(test_batch).data.max(1, keepdim=True)[1]
if predict[0][0] == 0:
    print(test_text, "is Bad Mean...")
else:
    print(test_text, "is Good Mean!!")
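As a side note, the same architecture can be written more compactly with nn.Embedding and nn.Conv1d. The class below is a hypothetical sketch of that variant (it is not the code from the repository above); max-over-time pooling is done with .max(dim=2) on each Conv1d output:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TextCNN1d(nn.Module):
    """Sketch of the same model using nn.Embedding and nn.Conv1d."""
    def __init__(self, vocab_size, embedding_size=2, num_filters=3,
                 filter_sizes=(2, 2, 2), num_classes=2):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, embedding_size)
        self.convs = nn.ModuleList(
            [nn.Conv1d(embedding_size, num_filters, k) for k in filter_sizes])
        self.fc = nn.Linear(num_filters * len(filter_sizes), num_classes)

    def forward(self, x):                         # x: [batch_size, sequence_length]
        e = self.emb(x).transpose(1, 2)           # [batch_size, embedding_size, sequence_length]
        # one max-over-time pooled feature vector per filter size
        pooled = [F.relu(conv(e)).max(dim=2).values for conv in self.convs]
        return self.fc(torch.cat(pooled, dim=1))  # [batch_size, num_classes]

Either formulation produces logits of shape [batch_size, num_classes] that can be fed to nn.CrossEntropyLoss exactly as in the training loop above.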