from keras.datasets import imdb
from keras.utils.np_utils import to_categorical
import numpy as np
from keras import models
from keras import layers
import matplotlib.pyplot as plt
#one-hot编码
def vectorize_sequences(sequences,dimension = 10000):
results = np.zeros((len(sequences),dimension))
for i,sequence in enumerate(sequences):
results[i,sequence] = 1
return results
#imdb是一个二分类问题
#一共有5w条数据,2.5w用于训练,2.5w用于测试
#每条数据是一个list,list里保存的是英文单词对应的排序
#num_words=10000表示保留前1w个常出现的单词
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=10000)
#下面的代码用来解码第一条数据的内容
data = x_train[0]
#word_index是一个dict,保存的是英文单词:单词排序位置
word_index = imdb.get_word_index()
index_word = dict((index,word) for (word,index) in word_index.items())
#i-3是because 0, 1 and 2 are reserved indices for "padding", "start of sequence", and "unknown".
data = ''.join(index_word.get(i-3,'?') for i in data)
######################################################
#神经网络的输入得是一个张量,使用one-hot编码处理数据
x_train = vectorize_sequences(x_train)
x_test = vectorize_sequences(x_test)
#keras的输入数据要转换为float类型,y是int类型,做一个类型转换 #构建神经网络
network = models.Sequential()
network.add(layers.Dense(16,activation='relu'))
network.add(layers.Dense(16,activation='relu'))
network.add(layers.Dense(1,activation='sigmoid')) #选择优化器、损失函数、评估准则
network.compile('rmsprop',loss='binary_crossentropy',metrics=['accuracy']) #训练模型
history = network.fit(x_train,y_train,epochs=5,batch_size=512,validation_split=0.2) history_dict = history.history
loss = history_dict['loss']
val_loss = history_dict['val_loss']
acc = history_dict['acc']
val_acc = history_dict['val_acc'] epochs = range(1,6)
#loss的图
plt.subplot(121)
plt.plot(epochs,loss,'g',label = 'Training loss')
plt.plot(epochs,val_loss,'b',label = 'Validation loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
#显示图例
plt.legend() plt.subplot(122)
plt.plot(epochs,acc,'g',label = 'Training accuracy')
plt.plot(epochs,val_acc,'b',label = 'Validation accuracy')
plt.xlabel('Epochs')
plt.ylabel('accuracy')
plt.legend()
plt.show() pre = network.predict(x_test)
print(pre)
print(y_test)

IMDB-二分类问题的更多相关文章

  1. 基于Keras的imdb数据集电影评论情感二分类

    IMDB数据集下载速度慢,可以在我的repo库中找到下载,下载后放到~/.keras/datasets/目录下,即可正常运行.)中找到下载,下载后放到~/.keras/datasets/目录下,即可正 ...

  2. 电影评论分类:二分类问题(IMDB数据集)

    IMDB数据集是Keras内部集成的,初次导入需要下载一下,之后就可以直接用了. IMDB数据集包含来自互联网的50000条严重两极分化的评论,该数据被分为用于训练的25000条评论和用于测试的250 ...

  3. Python深度学习案例1--电影评论分类(二分类问题)

    我觉得把课本上的案例先自己抄一遍,然后将书看一遍.最后再写一篇博客记录自己所学过程的感悟.虽然与课本有很多相似之处.但自己写一遍感悟会更深 电影评论分类(二分类问题) 本节使用的是IMDB数据集,使用 ...

  4. Python深度学习读书笔记-6.二分类问题

    电影评论分类:二分类问题   加载 IMDB 数据集 from keras.datasets import imdb (train_data, train_labels), (test_data, t ...

  5. 二分类问题 - 【老鱼学tensorflow2】

    什么是二分类问题? 二分类问题就是最终的结果只有好或坏这样的一个输出. 比如,这是好的,那是坏的.这个就是二分类的问题. 我们以一个电影评论作为例子来进行.我们对某部电影评论的文字内容为好评和差评. ...

  6. 二分类问题续 - 【老鱼学tensorflow2】

    前面我们针对电影评论编写了二分类问题的解决方案. 这里对前面的这个方案进行一些改进. 分批训练 model.fit(x_train, y_train, epochs=20, batch_size=51 ...

  7. keras框架下的深度学习(二)二分类和多分类问题

    本文第一部分是对数据处理中one-hot编码的讲解,第二部分是对二分类模型的代码讲解,其模型的建立以及训练过程与上篇文章一样:在最后我们将训练好的模型保存下来,再用自己的数据放入保存下来的模型中进行分 ...

  8. 【原】Spark之机器学习(Python版)(二)——分类

    写这个系列是因为最近公司在搞技术分享,学习Spark,我的任务是讲PySpark的应用,因为我主要用Python,结合Spark,就讲PySpark了.然而我在学习的过程中发现,PySpark很鸡肋( ...

  9. Kaggle实战之二分类问题

    0. 前言 1. MNIST 数据集 2. 二分类器 3. 效果评测 4. 多分类器与误差分析 5. Kaggle 实战 0. 前言 "尽管新技术新算法层出不穷,但是掌握好基础算法就能解决手 ...

  10. 准确率(Accuracy), 精确率(Precision), 召回率(Recall)和F1-Measure(对于二分类问题)

    首先我们可以计算准确率(accuracy),其定义是: 对于给定的测试数据集,分类器正确分类的样本数与总样本数之比.也就是损失函数是0-1损失时测试数据集上的准确率. 下面在介绍时使用一下例子: 一个 ...

随机推荐

  1. django.core.exceptions.ImproperlyConfigured: Error loading MySQLdb module: No module named 'MySQLdb'. Did you install mysqlclient or MySQL-python?

    Error msg: Unhandled exception in thread started by <function check_errors.<locals>.wrapper ...

  2. AtCoder Grand Contest 032-B - Balanced Neighbors (构造)

    Time Limit: 2 sec / Memory Limit: 1024 MB Score : 700700 points Problem Statement You are given an i ...

  3. Arguments Optional 计算两个参数之和的 function

    创建一个计算两个参数之和的 function.如果只有一个参数,则返回一个 function,该 function 请求一个参数然后返回求和的结果. 例如,add(2, 3) 应该返回 5,而 add ...

  4. Django(七)缓存、信号、Form

    大纲 一.缓存 1.1.五种缓存配置 1.2配置 2.1.三种应用(全局.视图函数.模板) 2.2 应用多个缓存时生效的优先级 二.信号 1.Django内置信号 2.自定义信号 三.Form 1.初 ...

  5. WebSocket群聊与单聊

    一 . WebSocket实现群聊 py文件代码 # py文件 from flask import Flask, render_template, request from geventwebsock ...

  6. 利用Python查看微信共同好友

    思路 首先通过itchat这个微信个人号接口扫码登录个人微信网页版,获取可以识别好友身份的数据.这里是需要分别登录两人微信的,拿到两人各自的好友信息存到列表中. 这样一来,查共同好友就转化成了查两个列 ...

  7. Spring Boot(二):数据库操作

    本文主要讲解如何通过spring boot来访问数据库,本文会演示三种方式来访问数据库,第一种是JdbcTemplate,第二种是JPA,第三种是Mybatis.之前已经提到过,本系列会以一个博客系统 ...

  8. delphi中响应鼠标进入或离开控件的方法

    Delphi没有MouseEnter与MouseLeave的事件,网上说可以响应CM_MOUSEENTER和CM_MOUSELEAVE消息来实现.这两个消息是VCL自己定义的消息,看了Delphi的C ...

  9. BZOJ5507 GXOI/GZOI2019旧词 (树链剖分+线段树)

    https://www.cnblogs.com/Gloid/p/9412357.html差分一下是一样的问题.感觉几年没写过树剖了. #include<iostream> #include ...

  10. 用CNN对CIFAR10进行分类(pytorch)

    CIFAR10有60000个\(32*32\)大小的有颜色的图像,一共10种类别,每种类别有6000个. 训练集一共50000个图像,测试集一共10000个图像. 先载入数据集 import nump ...