tensorflow在文本处理中的使用——词袋
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理
代码地址:https://github.com/nfmcclure/tensorflow-cookbook
解决问题:使用“词袋”嵌入来进行垃圾短信的预测(使用逻辑回归算法)
缺点:不考虑相关单词顺序特征,长文本的处理困难
步骤如下:
step1:导入需要的包
step2:准备数据集
step3:选择参数(每个文本保留多少单词数,最低词频是多少)
step4:构建词袋
step5:分割数据集
step6:构建图
step7:训练
step8:测试
step1:导入需要的包
import tensorflow as tf
import matplotlib.pyplot as plt
import os
import numpy as np
import csv
import string
import requests
import io
from zipfile import ZipFile
from tensorflow.contrib import learn
from tensorflow.python.framework import ops
ops.reset_default_graph() # Start a graph session
sess = tf.Session()
step2:准备数据集
# Check if data was downloaded, otherwise download it and save for future use
save_file_name = os.path.join('temp','temp_spam_data.csv')
if os.path.isfile(save_file_name):
text_data = []
with open(save_file_name, 'r') as temp_output_file:
reader = csv.reader(temp_output_file)
for row in reader:
text_data.append(row)
else:
zip_url = 'http://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip'
r = requests.get(zip_url)
z = ZipFile(io.BytesIO(r.content))
file = z.read('SMSSpamCollection')
# Format Data
text_data = file.decode()
text_data = text_data.encode('ascii',errors='ignore')
text_data = text_data.decode().split('\n')
text_data = [x.split('\t') for x in text_data if len(x)>=1] # And write to csv
with open(save_file_name, 'w') as temp_output_file:
writer = csv.writer(temp_output_file)
writer.writerows(text_data) texts = [x[1] for x in text_data]
target = [x[0] for x in text_data] # Relabel 'spam' as 1, 'ham' as 0
target = [1 if x=='spam' else 0 for x in target] # Normalize text,为减少无意义的词汇,对文本进行规则化处理
# Lower case
texts = [x.lower() for x in texts] # Remove punctuation
texts = [''.join(c for c in x if c not in string.punctuation) for x in texts] # Remove numbers
texts = [''.join(c for c in x if c not in '') for x in texts] # Trim extra whitespace
texts = [' '.join(x.split()) for x in texts]
step3:选择参数
# Plot histogram of text lengths文本数据中的单词数的直方图
text_lengths = [len(x.split()) for x in texts]
text_lengths = [x for x in text_lengths if x < 50]
plt.hist(text_lengths, bins=25)
plt.title('Histogram of # of Words in Texts')

step4:构建词袋
# Choose max text word length at 25,也可以设为30或者40
sentence_size = 25
min_word_freq = 3 # Setup vocabulary processor
vocab_processor = learn.preprocessing.VocabularyProcessor(sentence_size, min_frequency=min_word_freq) # Have to fit transform to get length of unique words.
vocab_processor.fit_transform(texts)
embedding_size = len(vocab_processor.vocabulary_)
step5:分割数据集
# Split up data set into train/test
train_indices = np.random.choice(len(texts), round(len(texts)*0.8), replace=False)
test_indices = np.array(list(set(range(len(texts))) - set(train_indices)))
texts_train = [x for ix, x in enumerate(texts) if ix in train_indices]
texts_test = [x for ix, x in enumerate(texts) if ix in test_indices]
target_train = [x for ix, x in enumerate(target) if ix in train_indices]
target_test = [x for ix, x in enumerate(target) if ix in test_indices]
step6:构建图
step6.1:构建出文本的向量
# Setup Index Matrix for one-hot-encoding,使用该矩阵为每个单词查找稀疏向量
identity_mat = tf.diag(tf.ones(shape=[embedding_size])) # Create variables for logistic regression
A = tf.Variable(tf.random_normal(shape=[embedding_size,1]))
b = tf.Variable(tf.random_normal(shape=[1,1])) # Initialize placeholders
x_data = tf.placeholder(shape=[sentence_size], dtype=tf.int32)
y_target = tf.placeholder(shape=[1, 1], dtype=tf.float32) # Text-Vocab Embedding,使用tf的嵌入查找函数来映射句子中的单词为单位矩阵的one-hot向量,再进行求和
# tf.nn.embedding_lookup(y,x)x为索引,找出y中对应索引的值
x_embed = tf.nn.embedding_lookup(identity_mat, x_data)
x_col_sums = tf.reduce_sum(x_embed, 0) # Declare model operations
x_col_sums_2D = tf.expand_dims(x_col_sums, 0)
疑问:如何利用词袋将文本变成向量?

step6.2 构建图
model_output = tf.add(tf.matmul(x_col_sums_2D, A), b) # Declare loss function (Cross Entropy loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(model_output, y_target)) # Prediction operation
prediction = tf.sigmoid(model_output) # Declare optimizer
my_opt = tf.train.GradientDescentOptimizer(0.001)
train_step = my_opt.minimize(loss)
step7:训练
# Intitialize Variables
init = tf.initialize_all_variables()
sess.run(init) # Start Logistic Regression
print('Starting Training Over {} Sentences.'.format(len(texts_train)))
loss_vec = []
train_acc_all = []
train_acc_avg = []
for ix, t in enumerate(vocab_processor.fit_transform(texts_train)):
y_data = [[target_train[ix]]] sess.run(train_step, feed_dict={x_data: t, y_target: y_data})
temp_loss = sess.run(loss, feed_dict={x_data: t, y_target: y_data})
loss_vec.append(temp_loss) if (ix+1)%10==0:
print('Training Observation #' + str(ix+1) + ': Loss = ' + str(temp_loss)) # Keep trailing average of past 50 observations accuracy
# Get prediction of single observation
[[temp_pred]] = sess.run(prediction, feed_dict={x_data:t, y_target:y_data})
# Get True/False if prediction is accurate
train_acc_temp = target_train[ix]==np.round(temp_pred)
train_acc_all.append(train_acc_temp)
if len(train_acc_all) >= 50:
train_acc_avg.append(np.mean(train_acc_all[-50:]))
step8:测试
# Get test set accuracy
print('Getting Test Set Accuracy For {} Sentences.'.format(len(texts_test)))
test_acc_all = []
for ix, t in enumerate(vocab_processor.fit_transform(texts_test)):
y_data = [[target_test[ix]]] if (ix+1)%50==0:
print('Test Observation #' + str(ix+1)) # Keep trailing average of past 50 observations accuracy
# Get prediction of single observation
[[temp_pred]] = sess.run(prediction, feed_dict={x_data:t, y_target:y_data})
# Get True/False if prediction is accurate
test_acc_temp = target_test[ix]==np.round(temp_pred)
test_acc_all.append(test_acc_temp) print('\nOverall Test Accuracy: {}'.format(np.mean(test_acc_all))) # Plot training accuracy over time
plt.plot(range(len(train_acc_avg)), train_acc_avg, 'k-', label='Train Accuracy')
plt.title('Avg Training Acc Over Past 50 Generations')
plt.xlabel('Generation')
plt.ylabel('Training Accuracy')
plt.show()
tensorflow在文本处理中的使用——词袋的更多相关文章
- tensorflow在文本处理中的使用——TF-IDF算法
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——CBOW词嵌入模型
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——Doc2Vec情感分析
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——skip-gram模型
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——Word2Vec预测
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- tensorflow在文本处理中的使用——skip-gram & CBOW原理总结
摘自:http://www.cnblogs.com/pinard/p/7160330.html 先看下列三篇,再理解此篇会更容易些(个人意见) skip-gram,CBOW,Word2Vec 词向量基 ...
- tensorflow在文本处理中的使用——辅助函数
代码来源于:tensorflow机器学习实战指南(曾益强 译,2017年9月)——第七章:自然语言处理 代码地址:https://github.com/nfmcclure/tensorflow-coo ...
- (数据科学学习手札71)在Python中制作个性化词云图
本文对应脚本及数据已上传至我的Github仓库https://github.com/CNFeffery/DataScienceStudyNotes 一.简介 词云图是文本挖掘中用来表征词频的数据可视化 ...
- TensorFlow实现文本情感分析详解
http://c.biancheng.net/view/1938.html 前面我们介绍了如何将卷积网络应用于图像.本节将把相似的想法应用于文本. 文本和图像有什么共同之处?乍一看很少.但是,如果将句 ...
随机推荐
- 实现手机网页调起原生微信朋友圈分享的工具nativeShare.js
http://www.liaoxiansheng.cn/?p=294 我们知道现在我们无法直接通过js直接跳转到微信和QQ等软件进行分享,但是现在像UC浏览器和QQ浏览器这样的主流浏览器自带一个分享工 ...
- [已转移]js事件流之事件冒泡的应用----事件委托
该文章已转移到博客:https://cynthia0329.github.io/ 什么是事件委托? 它还有一个名字叫事件代理. JavaScript高级程序设计上讲: 事件委托就是利用事件冒泡,只指定 ...
- 2019-9-2-dotnet-命名管道名字长度限制
title author date CreateTime categories dotnet 命名管道名字长度限制 lindexi 2019-09-02 11:54:50 +0800 2019-09- ...
- firefox扩展开发(一) : 扩展的基本结构
用过firefox的人肯定要安装firefox的扩展,这样才能发挥火狐的全部实力.一般扩展是一个后缀为.xpi的文件,其实这个文件就是zip格式的压缩包,压缩了一个扩展所需要的所有目录和文件,基本的目 ...
- 2019-6-23-WPF-网络-request-的-read-方法不会返回
title author date CreateTime categories WPF 网络 request 的 read 方法不会返回 lindexi 2019-06-23 11:26:26 +08 ...
- [mysql]MySQL Daemon failed to start 2016-08-14 21:27 1121人阅读 评论(18) 收藏
前两天我们发现发布好的网站不可以进行注册,登陆这些活动,但是访问页面是正常的.于是开始对问题进行排查,首先我们重启了jenkins,但是每次重启都有错误,于是我们只能重启服务器,重启服务器需要重新启动 ...
- 《C程序设计语言》笔记(一)
一:导言 1:printf中的格式化字符串: %ld 按照long整型打印 %6d 按照十进制整数打印,至少6个字符宽,不够的 ...
- Android中解析Json数据
在开发中常常会遇到解析json的问题 在这里总结几种解析的方式: 方式一: json数据: private String jsonData = "[{\"name\":\ ...
- form表单提交,后台怎么获取select的值?后台直接获取即可,和input方式一样。
form表单提交,后台怎么获取select的值? 后台直接获取即可,和后台获取input的值方式一样. form提交后,后台直接根据select的name获取即可,即getPara("XXX ...
- oracle函数 TO_DATE(X[,c2[,c3]])
[功能]将字符串X转化为日期型 [参数]c2,c3,字符型,参照to_char() [返回]字符串 如果x格式为日期型(date)格式时,则相同表达:date x 如果x格式为日期时间型(timest ...