Keras实现MNIST分类
仅仅为了学习Keras的使用,使用一个四层的全连接网络对MNIST数据集进行分类,网络模型各层结点数为:784: 256: 128 : 10;
使用整体数据集的75%作为训练集,25%作为测试集,最终在测试集上的正确率也就只能达到92%,太低了:
precision recall f1-score support
0.0 0.95 0.96 0.96 1721
1.0 0.95 0.97 0.96 1983
2.0 0.91 0.90 0.91 1793
3.0 0.91 0.88 0.89 1833
4.0 0.92 0.93 0.92 1689
5.0 0.87 0.86 0.87 1598
6.0 0.92 0.95 0.94 1699
7.0 0.94 0.93 0.93 1817
8.0 0.89 0.87 0.88 1721
9.0 0.89 0.90 0.89 1646
micro avg 0.92 0.92 0.92 17500
macro avg 0.91 0.92 0.91 17500
weighted avg 0.92 0.92 0.92 17500
训练过程中,损失和正确率曲线:

上面使用的优化方法是SGD,下面在保持所有参数不变的情况下,使用RMSpro进行优化,最后的结果看起来好了不少啊达到98%:
precision recall f1-score support
0.0 0.99 0.99 0.99 1719
1.0 0.99 0.99 0.99 1997
2.0 0.99 0.97 0.98 1785
3.0 0.97 0.98 0.97 1746
4.0 0.98 0.98 0.98 1654
5.0 0.97 0.97 0.97 1556
6.0 0.99 0.99 0.99 1740
7.0 0.98 0.97 0.98 1845
8.0 0.97 0.98 0.97 1669
9.0 0.96 0.97 0.97 1789
micro avg 0.98 0.98 0.98 17500
macro avg 0.98 0.98 0.98 17500
weighted avg 0.98 0.98 0.98 17500

这相差有些大了啊,损失一开始就比较低,正确率就比较高;
在执行一些epoch之后,训练集损失降低到一定程度并趋于平缓,而测试集的损失却在逐渐升高,看来是过拟合了啊。
也就是说大概在最初几个epoch的时候训练集的损失比较低,而且测试集的损失达到最低,正确率也都不错;应该就可以停止了。
代码:
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Time : 19-5-9
"""
implement simple neural networks using the Keras libary
"""
__author__ = 'Zhen Chen'
# import the necessary packages
from sklearn.preprocessing import LabelBinarizer
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from keras.models import Sequential
from keras.layers.core import Dense
from keras.optimizers import SGD
from sklearn import datasets
import matplotlib.pyplot as plt
import numpy as np
import argparse
# construct the argument parse and parse the arguments
parser = argparse.ArgumentParser()
parser.add_argument("-o", "--output", default="./Training Loss and Accuracy.png",
help="path to the output loss/accuracy plot")
args = parser.parse_args()
# grab the MNIST datset (if this is your first time running this
# script, the download may take a minute -- the 55MB MNIST datset
# will be downloaded
print("[INFO] loading MNIST (full) dataset...")
dataset = datasets.fetch_mldata("MNIST Original")
# scale the raw pixel intensities to the range [0, 1.0], then
# construct the training and testing splits
data = dataset.data.astype("float") / 255.0
(trainX, testX, trainY, testY) = train_test_split(data, dataset.target, test_size=0.25)
# convert the labels form integers to vectors
lb = LabelBinarizer()
trainY = lb.fit_transform(trainY)
testY = lb.fit_transform(testY)
# define the 784-256-128-10 architecture using Keras
model = Sequential()
model.add(Dense(256, input_shape=(784,), activation="sigmoid"))
model.add(Dense(128, activation="sigmoid"))
model.add(Dense(10, activation="softmax"))
# train the model using SGD
print("[INFO] training network...")
sgd = SGD(0.01)
model.compile(loss="categorical_crossentropy", optimizer=sgd, metrics=["accuracy"])
H = model.fit(trainX, trainY, validation_data=(testX, testY), epochs=100, batch_size=128)
# evaluate the network
print("[INFO] evaluating network...")
predictions = model.predict(testX, batch_size=128)
print(classification_report(testY.argmax(axis=1),
predictions.argmax(axis=1),
target_names=[str(x) for x in lb.classes_]))
# plot the training loss and accuracy
plt.style.use("ggplot")
plt.figure()
plt.plot(np.arange(0, 100), H.history["loss"], label="train_loss")
plt.plot(np.arange(0, 100), H.history["val_loss"], label="val_loss")
plt.plot(np.arange(0, 100), H.history["acc"], label="train_acc")
plt.plot(np.arange(0, 100), H.history["val_acc"], label="val_loss")
plt.title("Training Loss and Accuracy")
plt.xlabel("Epoch #")
plt.ylabel("Loss/Accuracy")
plt.legend()
plt.savefig(args.output)
Keras实现MNIST分类的更多相关文章
- 芝麻HTTP:TensorFlow LSTM MNIST分类
本节来介绍一下使用 RNN 的 LSTM 来做 MNIST 分类的方法,RNN 相比 CNN 来说,速度可能会慢,但可以节省更多的内存空间. 初始化 首先我们可以先初始化一些变量,如学习率.节点单元数 ...
- TensorFlow入门(三)多层 CNNs 实现 mnist分类
欢迎转载,但请务必注明原文出处及作者信息. 深入MNIST refer: http://wiki.jikexueyuan.com/project/tensorflow-zh/tutorials/mni ...
- keras实现mnist数据集手写数字识别
一. Tensorflow环境的安装 这里我们只讲CPU版本,使用 Anaconda 进行安装 a.首先我们要安装 Anaconda 链接:https://pan.baidu.com/s/1AxdGi ...
- 用Pytorch训练MNIST分类模型
本次分类问题使用的数据集是MNIST,每个图像的大小为\(28*28\). 编写代码的步骤如下 载入数据集,分别为训练集和测试集 让数据集可以迭代 定义模型,定义损失函数,训练模型 代码 import ...
- 利用CNN神经网络实现手写数字mnist分类
题目: 1)In the first step, apply the Convolution Neural Network method to perform the training on one ...
- keras如何求分类问题中的准确率和召回率
https://www.zhihu.com/question/53294625 由于要用keras做一个多分类的问题,评价标准采用precision,recall,和f1_score:但是keras中 ...
- 深度学习原理与框架-Tensorflow卷积神经网络-卷积神经网络mnist分类 1.tf.nn.conv2d(卷积操作) 2.tf.nn.max_pool(最大池化操作) 3.tf.nn.dropout(执行dropout操作) 4.tf.nn.softmax_cross_entropy_with_logits(交叉熵损失) 5.tf.truncated_normal(两个标准差内的正态分布)
1. tf.nn.conv2d(x, w, strides=[1, 1, 1, 1], padding='SAME') # 对数据进行卷积操作 参数说明:x表示输入数据,w表示卷积核, stride ...
- 深度学习原理与框架-Tensorflow卷积神经网络-神经网络mnist分类
使用tensorflow构造神经网络用来进行mnist数据集的分类 相比与上一节讲到的逻辑回归,神经网络比逻辑回归多了隐藏层,同时在每一个线性变化后添加了relu作为激活函数, 神经网络使用的损失值为 ...
- 深度学习之神经网络核心原理与算法-caffe&keras框架图片分类
之前我们在使用cnn做图片分类的时候使用了CIFAR-10数据集 其他框架对于CIFAR-10的图片分类是怎么做的 来与TensorFlow做对比. Caffe Keras 安装 官方安装文档: ht ...
随机推荐
- BZOJ4944: [Noi2017]泳池
BZOJ4944: [Noi2017]泳池 题目背景 久莲是个爱玩的女孩子. 暑假终于到了,久莲决定请她的朋友们来游泳,她打算先在她家的私人海滩外圈一块长方形的海域作为游泳场. 然而大海里有着各种各样 ...
- 取得微信用户OpenID
公司需要微信这个平台和用户交流,于是开始研究微信公众平台.微信公众平台分为两种模式,其一是编辑模式,比如用户发什么内容,你可以响应什么内容.另外一种便是开发模式,这个模式功能丰富,不仅仅可以获取到用户 ...
- ABAP 调用webservice 错误
错误:1.soamanager 配置端口错误: 调整端口后报错: java端回复: 嗯 有问题了我待会儿看看应该是数据有问题
- mini2440 u-boot禁止蜂鸣器
mini2440的u-boot版本启动之后马上就会开启蜂鸣器,在办公环境下有可能会影响同事的工作,所以我考虑将其禁止掉. 我使用的mini2440使用的光盘是2013年10月的版本,我在该光盘下的u- ...
- js正则表达式(2)
找到以某个字符串开头的字符串 var myReg=/^(abc)/gim; 如果不加m,那么只找一行,而加了m可以找到每行中以该字符串开头的匹配文本. 如: abcsfsdfasd7890hklfah ...
- 京东面试题 Java相关
1.JVM的内存结构和管理机制: JVM实例:一个独立运行的java程序,是进程级别 JVM执行引擎:用户运行程序的线程,是JVM实例的一部分 JVM实例的诞生 当启动一个java程序时.一个JVM实 ...
- 前端多媒体(4)—— video标签全面分析
测试地址:https://young-cowboy.github.io/gallery/html5_video/index.html 属性 一些属性是只读的,一些属性是可以修改从而影响视频播放的. a ...
- oracle 11g 常用命令
sqlplus system/123@ORCL; 查看oracle字符集: select * from nls_database_parameters where parameter ='NLS_CH ...
- bzoj 4503 两个串 快速傅里叶变换FFT
题目大意: 给定两个\((length \leq 10^5)\)的字符串,问第二个串在第一个串中出现了多少次.并且第二个串中含有单字符通配符. 题解: 首先我们从kmp的角度去考虑 这道题从字符串数据 ...
- AtCoder Regular Contest 073 E:Ball Coloring
题目传送门:https://arc073.contest.atcoder.jp/tasks/arc073_c 题目翻译 给你\(N\)个袋子,每个袋子里有俩白球,白球上写了数字.对于每一个袋子,你需要 ...