keras-基于CNN网络的Mnist数据集分类

1.数据的载入和预处理

import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import *
from keras.optimizers import SGD,Adam
from keras.regularizers import l2
from keras.utils.vis_utils import plot_model
from matplotlib import pyplot as plt import os import tensorflow as tf # 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data() # 预处理
# 将(60000,28,28)转化为(-1,28,28,1),最后1是图片深度 x_train = x_train.reshape(-1,28,28,1)/255.0
x_test= x_test.reshape(-1,28,28,1)/255.0
# 将输出转化为one_hot编码
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)

2.创建网络打印训练数据# 创建网络

model = Sequential([

    # 创建卷积层提取特征,对卷积核进行正则化
Conv2D(input_shape=(28,28,1),filters=32,kernel_size=5,strides=1,padding='same',activation='relu',kernel_regularizer=l2(0.01)),
   # 池化层,对特征的筛选
   MaxPool2D(pool_size=(2,2),strides=2,padding='same'), Flatten(),
  
   Dense(units=128,input_dim=784,bias_initializer='one',activation='tanh'),
    Dropout(0.2),
Dense(units=10,bias_initializer='one',activation='softmax')
]) # 编译
# 自定义优化器
sgd = SGD(lr=0.1)
adma = Adam(lr=0.001)
# 运用交叉熵
model.compile(optimizer=adma,
loss='categorical_crossentropy',
# 得到训练过程中的准确率
metrics=['accuracy']) model.fit(x_train,y_train,batch_size=32,epochs=10,validation_split=0.2) # 评估模型
loss,acc = model.evaluate(x_test,y_test,)
print('\ntest loss',loss)
print('test acc',acc)

out:

Epoch 1/10

32/48000 [..............................] - ETA: 10:18 - loss: 2.7563 - acc: 0.1562
96/48000 [..............................] - ETA: 3:53 - loss: 2.6141 - acc: 0.1354

......

......

Epoch 10/10

45952/48000 [===========================>..] - ETA: 0s - loss: 0.0664 - acc: 0.9905
47616/48000 [============================>.] - ETA: 0s - loss: 0.0663 - acc: 0.9908
48000/48000 [==============================] - 2s 37us/step - loss: 0.0663 - acc: 0.9910 - val_loss: 0.0149 - val_acc: 0.9884

32/10000 [..............................] - ETA: 4s
3360/10000 [=========>....................] - ETA: 0s
5824/10000 [================>.............] - ETA: 0s
8512/10000 [========================>.....] - ETA: 0s
10000/10000 [==============================] - 0s 20us/step

test loss 0.015059704356454312
test acc 0.988

6.keras-基于CNN网络的Mnist数据集分类的更多相关文章

  1. 基于CNN网络的汉字图像字体识别及其原理

    现代办公要将纸质文档转换为电子文档的需求越来越多,目前针对这种应用场景的系统为OCR系统,也就是光学字符识别系统,例如对于古老出版物的数字化.但是目前OCR系统主要针对文字的识别上,对于出版物的版面以 ...

  2. 3.keras-简单实现Mnist数据集分类

    keras-简单实现Mnist数据集分类 1.载入数据以及预处理 import numpy as np from keras.datasets import mnist from keras.util ...

  3. 机器学习与Tensorflow(3)—— 机器学习及MNIST数据集分类优化

    一.二次代价函数 1. 形式: 其中,C为代价函数,X表示样本,Y表示实际值,a表示输出值,n为样本总数 2. 利用梯度下降法调整权值参数大小,推导过程如下图所示: 根据结果可得,权重w和偏置b的梯度 ...

  4. 数据挖掘入门系列教程(十二)之使用keras构建CNN网络识别CIFAR10

    简介 在上一篇博客:数据挖掘入门系列教程(十一点五)之CNN网络介绍中,介绍了CNN的工作原理和工作流程,在这一篇博客,将具体的使用代码来说明如何使用keras构建一个CNN网络来对CIFAR-10数 ...

  5. 【TensorFlow/简单网络】MNIST数据集-softmax、全连接神经网络,卷积神经网络模型

    初学tensorflow,参考了以下几篇博客: soft模型 tensorflow构建全连接神经网络 tensorflow构建卷积神经网络 tensorflow构建卷积神经网络 tensorflow构 ...

  6. 深度学习(一)之MNIST数据集分类

    任务目标 对MNIST手写数字数据集进行训练和评估,最终使得模型能够在测试集上达到\(98\%\)的正确率.(最终本文达到了\(99.36\%\)) 使用的库的版本: python:3.8.12 py ...

  7. CNN算法解决MNIST数据集识别问题

    网络实现程序如下 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 用于设置将记 ...

  8. keras基于卷积网络手写数字识别

    import time import keras from keras.utils import np_utils start = time.time() (x_train, y_train), (x ...

  9. 神经网络MNIST数据集分类tensorboard

    今天分享同样数据集的CNN处理方式,同时加上tensorboard,可以看到清晰的结构图,迭代1000次acc收敛到0.992 先放代码,注释比较详细,变量名字看单词就能知道啥意思 import te ...

随机推荐

  1. Analysis分析器

    一.Analysis简介 场景执行过程中,loadrunner收集执行过程中的数据,存储在扩展名为.lrr的文件中,Analysis分析器打开这个文件,对文件信息进行处理,并生成图和报告. 数据分析不 ...

  2. python selenium unittest Fixture(setUp/tearDown)笔记

    Fixture用途: 1.做测试前后的初始化设置,如测试数据准备,链接数据库,打开浏览器等这些操作都可以使用fixture来实现 2.测试用例的前置条件可以使用fixture实现 Fixture使用: ...

  3. 浅谈PostgreSQL用户权限

    问题 经常在PG群里看到有人在问“为什么我对表赋予了权限:但是还是不能访问表” 解析 若你看懂德哥这篇文章PostgreSQL逻辑结构和权限体系介绍:上面对你就不是困扰你的问题 解决这个问题很简单:在 ...

  4. JavaScript和TypeScript的区别和联系

    转载自:http://web.jobbole.com/93618/?utm_source=group.jobbole.com&utm_medium=relatedArticles JavaSc ...

  5. Spring+Struts2+Hibernate框架搭建

    SSH框架版本:Struts-2.3.30  +  Spring-4.2.2  +  Hibernate5.2.2 下图是所需要的Jar包: 下面是项目的结构图: 1.web.xml <?xml ...

  6. Spring 基于 Java 的配置

    前面已经学习如何使用 XML 配置文件来配置 Spring bean. 基于 Java 的配置可以达到基于XML配置的相同效果. 基于 Java 的配置选项,可以使你在不用配置 XML 的情况下编写大 ...

  7. JUC整理笔记一之细说Unsafe

    JUC(java.util.concurrent)的开始,可以说是从Unsafe类开始. Unsafe 简介 Unsafe在sun.misc 下,顾名思义,这是一个不安全的类,因为Unsafe类所操作 ...

  8. IIS 报 :HTTP Error 503. The service is unavailable.

    打开IIS 找到你对应的网站名称然后你会发现应用池停止了 点击你对应的网站右键点击启动既可

  9. redis配置文件.conf和常用配置

    1,配置文件在哪 2,Units单位 1 配置大小单位,开头定义了一些基本的度量单位,只支持bytes,不支持bit 2 对大小写不敏感 3,INCLUDES包含 和我们的spring配置文件类似,可 ...

  10. 特效 css3 渐变背景框

    .box{ 子级 position: relative; width: 300px; height: 400px; display: flex; justify-content: center; al ...